Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git, synced 2026-04-03 03:00:17 -04:00.

Commit: Complete SD pipeline.
@@ -4,8 +4,10 @@ import time
import os
import json
import numpy as np
from tqdm.auto import tqdm

from pathlib import Path
from random import randint
from turbine_models.custom_models.sd_inference import clip, unet, vae
from apps.shark_studio.api.controlnet import control_adapter_map
from apps.shark_studio.web.utils.state import status_label
@@ -16,6 +18,8 @@ from apps.shark_studio.modules.prompt_encoding import get_weighted_text_embeddings
from apps.shark_studio.modules.img_processing import (
    resize_stencil,
    save_output_img,
    resamplers,
    resampler_list,
)

from apps.shark_studio.modules.ckpt_processing import (
@@ -30,8 +34,7 @@ sd_model_map = {
    "clip": {
        "initializer": clip.export_clip_model,
        "external_weight_file": None,
        "ireec_flags": ["--iree-flow-collapse-reduction-dims",
        ],
        "ireec_flags": ["--iree-flow-collapse-reduction-dims"],
    },
    "vae_encode": {
        "initializer": vae.export_vae_model,
@@ -48,6 +51,10 @@ sd_model_map = {
    "vae_decode": {
        "initializer": vae.export_vae_model,
        "external_weight_file": None,
        "ireec_flags": ["--iree-flow-collapse-reduction-dims",
            "--iree-opt-const-expr-hoisting=False",
            "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
        ],
    },
}
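
Each sd_model_map entry pairs an export "initializer" with optional external weights and per-submodel IREE compile flags. As a rough sketch of how an entry is consumed (mirroring SharkPipelineBase.import_torch_ir later in this commit; the wrapper name here is ours, not the repo's):

def export_submodel_ir(model_map, submodel, **kwargs):
    # Ask the per-submodel export function for torch-level IR; IREE then
    # compiles that IR using the entry's "ireec_flags".
    return model_map[submodel]["initializer"](**kwargs, compile_to="torch")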

@@ -78,15 +85,15 @@ class StableDiffusion(SharkPipelineBase):
        custom_vae: str = None,
        num_loras: int = 0,
        import_ir: bool = True,
        is_img2img: bool = False,
        is_controlled: bool = False,
    ):
        self.model_max_length = 77
        self.batch_size = batch_size
        self.precision = precision
        self.is_img2img = is_img2img
        self.dtype = torch.float16 if precision == "fp16" else torch.float32
        self.height = height
        self.width = width
        self.scheduler_obj = {}
        self.precision = precision
        static_kwargs = {
            "pipe": {},
            "clip": {"hf_model_name": base_model_id},
@@ -98,6 +105,8 @@ class StableDiffusion(SharkPipelineBase):
                #"num_loras": num_loras,
                "height": height,
                "width": width,
                "precision": precision,
                "max_length": 77 * 8,
            },
            "vae_encode": {
                "hf_model_name": custom_vae if custom_vae else base_model_id,
@@ -105,13 +114,15 @@ class StableDiffusion(SharkPipelineBase):
                "batch_size": batch_size,
                "height": height,
                "width": width,
                "precision": precision,
            },
            "vae_decode": {
                "hf_model_name": custom_vae,
                "hf_model_name": custom_vae if custom_vae else base_model_id,
                "vae_model": vae.VaeModel(hf_model_name=base_model_id, hf_auth_token=None),
                "batch_size": batch_size,
                "height": height,
                "width": width,
                "precision": precision,
            },
        }
        super().__init__(
@@ -135,26 +146,26 @@ class StableDiffusion(SharkPipelineBase):
        gc.collect()


    def prepare_pipe(self, scheduler, custom_weights, adapters, embeddings):
    def prepare_pipe(self, scheduler, custom_weights, adapters, embeddings, is_img2img):
        print(
            f"\n[LOG] Preparing pipeline with scheduler {scheduler}"
            f"\n[LOG] Custom embeddings currently unsupported."
        )
        self.is_img2img = is_img2img
        schedulers = get_schedulers(self.base_model_id)
        self.weights_path = get_checkpoints_path(self.pipe_id)
        self.weights_path = get_checkpoints_path(self.safe_name(self.pipe_id))
        if not os.path.exists(self.weights_path):
            os.mkdir(self.weights_path)
        # accepting a list of schedulers in batched cases.
        for i in scheduler:
            self.scheduler_obj[i] = schedulers[i]
            print(f"[LOG] Loaded scheduler: {i}")
        self.scheduler = schedulers[scheduler]
        print(f"[LOG] Loaded scheduler: {scheduler}")
        for model in adapters:
            self.model_map[model] = adapters[model]
        if os.path.isfile(custom_weights):
            for i in self.model_map:
                self.model_map[i]["external_weights_file"] = None
        elif custom_weights != "":
            print(f"\n[LOG][WARNING] Custom weights were not found at {custom_weights}. Did you mean to pass a base model ID?")
        if custom_weights:
            if os.path.isfile(custom_weights):
                for i in self.model_map:
                    self.model_map[i]["external_weights_file"] = None
            elif custom_weights:
                print(f"\n[LOG][WARNING] Custom weights were not found at {custom_weights}. Did you mean to pass a base model ID?")
        self.static_kwargs["pipe"] = {
            # "external_weight_path": self.weights_path,
            # "external_weights": "safetensors",
@@ -162,6 +173,92 @@ class StableDiffusion(SharkPipelineBase):
        self.get_compiled_map(pipe_id=self.pipe_id)
        print("\n[LOG] Pipeline successfully prepared for runtime.")
        return


    def generate_images(
        self,
        prompt,
        negative_prompt,
        image,
        steps,
        strength,
        guidance_scale,
        seed,
        ondemand,
        repeatable_seeds,
        use_base_vae,
        resample_type,
        control_mode,
        hints,
    ):
        #TODO: Batched args
        self.ondemand = ondemand
        if self.is_img2img:
            image, _ = self.process_sd_init_image(image, resample_type)
        else:
            image = None

        print("\n[LOG] Generating images...")
        batched_args=[
            prompt,
            negative_prompt,
            #steps,
            #strength,
            #guidance_scale,
            #seed,
            #resample_type,
            #control_mode,
            #hints,
        ]
        for arg in batched_args:
            if not isinstance(arg, list):
                arg = [arg] * self.batch_size
            if len(arg) < self.batch_size:
                arg = arg * self.batch_size
            else:
                arg = [arg[i] for i in range(self.batch_size)]
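
One caveat on the broadcast loop above: rebinding the loop variable arg does not write back into batched_args, so the list is left unchanged as committed, and repeating a short list can overshoot batch_size. A write-back sketch of the apparent intent (ours, not part of the commit):

for idx, arg in enumerate(batched_args):
    if not isinstance(arg, list):
        batched_args[idx] = [arg] * self.batch_size
    elif len(arg) < self.batch_size:
        # repeat the short list, then trim to exactly batch_size entries
        batched_args[idx] = (arg * self.batch_size)[: self.batch_size]
    else:
        batched_args[idx] = arg[: self.batch_size]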

        text_embeddings = self.encode_prompts_weight(
            prompt,
            negative_prompt,
        )

        uint32_info = np.iinfo(np.uint32)
        uint32_min, uint32_max = uint32_info.min, uint32_info.max
        if seed < uint32_min or seed >= uint32_max:
            seed = randint(uint32_min, uint32_max)

        generator = torch.manual_seed(seed)
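
The guard above keeps seeds inside the uint32 domain: anything out of range, including the -1 sentinel used in the default config, is replaced with a fresh random seed before torch.manual_seed is called. A standalone sketch of the same rule (the function name is ours):

import numpy as np
from random import randint

def sanitize_seed(seed: int) -> int:
    # Out-of-range or sentinel seeds (e.g. -1) fall back to a random uint32.
    info = np.iinfo(np.uint32)
    if seed < info.min or seed >= info.max:
        seed = randint(info.min, info.max)
    return seed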

        init_latents, final_timesteps = self.prepare_latents(
            generator=generator,
            num_inference_steps=steps,
            image=image,
            strength=strength,
        )

        latents = self.produce_img_latents(
            latents=init_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            total_timesteps=final_timesteps,
            cpu_scheduling=True,  # until we have schedulers through Turbine
        )

        # Img latents -> PIL images
        all_imgs = []
        self.load_submodels(["vae_decode"])
        for i in tqdm(range(0, latents.shape[0], self.batch_size)):
            imgs = self.decode_latents(
                latents=latents[i : i + self.batch_size],
                use_base_vae=use_base_vae,
                cpu_scheduling=True,
            )
            all_imgs.extend(imgs)
        if self.ondemand:
            self.unload_submodels(["vae_decode"])

        return all_imgs


    def encode_prompts_weight(
@@ -191,84 +288,220 @@ class StableDiffusion(SharkPipelineBase):
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        pad = (0, 0) * (len(text_embeddings.shape) - 2)
        pad = pad + (0, 512 - text_embeddings.shape[1])
        pad = pad + (0, 77 * 8 - text_embeddings.shape[1])
        text_embeddings = torch.nn.functional.pad(text_embeddings, pad)

        # SHARK: Report clip inference time
        clip_inf_time = (time.time() - clip_inf_start) * 1000
        if self.ondemand:
            self.unload_clip()
            self.unload_submodels(["clip"])
            gc.collect()
        print(f"\n[LOG] Clip Inference time (ms) = {clip_inf_time:.3f}")

        return text_embeddings.numpy().astype(np.float16)


    def generate_images(

    def prepare_latents(
        self,
        prompt,
        negative_prompt,
        steps,
        generator,
        num_inference_steps,
        image,
        strength,
        guidance_scale,
        seed,
        ondemand,
        repeatable_seeds,
        resample_type,
        control_mode,
        hints,
    ):
        print("\n[LOG] Generating images...")
        batched_args=[
            prompt,
            negative_prompt,
            steps,
            strength,
            guidance_scale,
            seed,
            resample_type,
            control_mode,
            hints,
        ]
        for arg in batched_args:
            if not isinstance(arg, list):
                arg = [arg] * self.batch_size
            if len(arg) < self.batch_size:
                arg = arg * self.batch_size
            else:
                arg = [arg[i] for i in range(self.batch_size)]
        noise = torch.randn(
            (
                self.batch_size,
                4,
                self.height // 8,
                self.width // 8,
            ),
            generator=generator,
            dtype=self.dtype,
        ).to("cpu")

        self.scheduler.set_timesteps(num_inference_steps)
        if self.is_img2img:
            init_timestep = min(
                int(num_inference_steps * strength), num_inference_steps
            )
            t_start = max(num_inference_steps - init_timestep, 0)
            timesteps = self.scheduler.timesteps[t_start:]
            latents = self.encode_image(image)
            latents = self.scheduler.add_noise(
                latents, noise, timesteps[0].repeat(1)
            )
            return latents, [timesteps]
        else:
            self.scheduler.is_scale_input_called = True
            latents = noise * self.scheduler.init_noise_sigma
            return latents, self.scheduler.timesteps
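
A worked example for the img2img branch above: with num_inference_steps=50 and strength=0.8, init_timestep = min(int(50 * 0.8), 50) = 40 and t_start = 50 - 40 = 10, so the encoded init image is noised to the scheduler's tenth timestep and denoising runs over the remaining 40 steps; at strength=1.0 the slice covers the full 50-step schedule.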

        text_embeddings = self.encode_prompts_weight(
            prompt,
            negative_prompt,
        )
        print(text_embeddings)
        test_img = [
            Image.open(
                get_resource_path("../../tests/jupiter.png"), mode="r"
            ).convert("RGB")
        ] * self.batch_size
        return test_img

    def encode_image(self, input_image):
        self.load_submodels(["vae_encode"])
        vae_encode_start = time.time()
        latents = self.run("vae_encode", input_image)
        vae_inf_time = (time.time() - vae_encode_start) * 1000
        if self.ondemand:
            self.unload_submodels(["vae_encode"])
        print(f"\n[LOG] VAE Encode Inference time (ms): {vae_inf_time:.3f}")

        return latents


    def produce_img_latents(
        self,
        latents,
        text_embeddings,
        guidance_scale,
        total_timesteps,
        cpu_scheduling,
        mask=None,
        masked_image_latents=None,
        return_all_latents=False,
    ):
        # self.status = SD_STATE_IDLE
        step_time_sum = 0
        latent_history = [latents]
        text_embeddings = torch.from_numpy(text_embeddings).to(torch.float16)
        text_embeddings_numpy = text_embeddings.detach().numpy()
        guidance_scale = np.asarray([guidance_scale], dtype=np.float16)
        self.load_submodels(["unet"])
        for i, t in tqdm(enumerate(total_timesteps)):
            step_start_time = time.time()
            timestep = torch.tensor([t]).to(self.dtype).detach().numpy()
            latent_model_input = self.scheduler.scale_model_input(latents, t).to(self.dtype)
            if mask is not None and masked_image_latents is not None:
                latent_model_input = torch.cat(
                    [
                        torch.from_numpy(np.asarray(latent_model_input)).to(torch.float16),
                        mask,
                        masked_image_latents,
                    ],
                    dim=1,
                ).to(self.dtype)
            if cpu_scheduling:
                latent_model_input = latent_model_input.detach().numpy()

            # Profiling Unet.
            # profile_device = start_profiling(file_path="unet.rdc")
            noise_pred = self.run(
                "unet",
                [
                    latent_model_input,
                    timestep,
                    text_embeddings_numpy,
                    guidance_scale,
                ],
            )
            # end_profiling(profile_device)

            if cpu_scheduling:
                noise_pred = torch.from_numpy(noise_pred.to_host())
                latents = self.scheduler.step(
                    noise_pred, t, latents
                ).prev_sample
            else:
                latents = self.run("scheduler_step", (noise_pred, t, latents))

            latent_history.append(latents)
            step_time = (time.time() - step_start_time) * 1000
            # self.log += (
            #     f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
            # )
            step_time_sum += step_time

            #if self.status == SD_STATE_CANCEL:
            #    break

        if self.ondemand:
            self.unload_submodels(["unet"])
            gc.collect()

        avg_step_time = step_time_sum / len(total_timesteps)
        print(f"\n[LOG] Average step time: {avg_step_time}ms/it")

        if not return_all_latents:
            return latents
        all_latents = torch.cat(latent_history, dim=0)
        return all_latents
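
Note that the compiled unet is invoked with guidance_scale as a runtime input, which suggests classifier-free guidance is folded into the exported module rather than applied on the host. For reference only, the conventional host-side CFG combination (not what this loop does) would be:

# Conventional host-side CFG step, shown for comparison only:
noise_uncond, noise_text = noise_pred.chunk(2)
noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)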


    def decode_latents(self, latents, use_base_vae, cpu_scheduling):
        if use_base_vae:
            latents = 1 / 0.18215 * latents

        latents_numpy = latents.to(self.dtype)
        if cpu_scheduling:
            latents_numpy = latents.detach().numpy()

        #profile_device = start_profiling(file_path="vae.rdc")
        vae_start = time.time()
        images = self.run("vae_decode", latents_numpy).to_host()
        vae_inf_time = (time.time() - vae_start) * 1000
        #end_profiling(profile_device)
        print(f"\n[LOG] VAE Inference time (ms): {vae_inf_time:.3f}")

        if use_base_vae:
            images = torch.from_numpy(images)
            images = (images.detach().cpu() * 255.0).numpy()
            images = images.round()

        images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
        pil_images = [Image.fromarray(image) for image in images.numpy()]
        return pil_images
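
The 1 / 0.18215 factor above undoes the standard Stable Diffusion latent scaling constant (exposed by diffusers as vae.config.scaling_factor), returning latents to the VAE's native range before decoding on the base-VAE path.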


    def process_sd_init_image(self, sd_init_image, resample_type):
        if isinstance(sd_init_image, list):
            images = []
            for img in sd_init_image:
                img, _ = self.process_sd_init_image(img, resample_type)
                images.append(img)
            is_img2img = True
            return images, is_img2img
        if isinstance(sd_init_image, str):
            if os.path.isfile(sd_init_image):
                sd_init_image = Image.open(sd_init_image, mode="r").convert("RGB")
                image, is_img2img = self.process_sd_init_image(sd_init_image, resample_type)
            else:
                image = None
                is_img2img = False
        elif isinstance(sd_init_image, Image.Image):
            image = sd_init_image.convert("RGB")
        elif sd_init_image:
            image = sd_init_image["image"].convert("RGB")
        else:
            image = None
            is_img2img = False
        if image:
            resample_type = (
                resamplers[resample_type]
                if resample_type in resampler_list
                # Fallback to Lanczos
                else Image.Resampling.LANCZOS
            )
            image = image.resize((self.width, self.height), resample=resample_type)
            image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
            image_arr = image_arr / 255.0
            image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype)
            image_arr = 2 * (image_arr - 0.5)
            is_img2img = True
            image = image_arr
        return image, is_img2img
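
The closing arithmetic maps pixels into the [-1, 1] range the VAE encoder expects: dividing by 255.0 gives [0, 1], and 2 * (x - 0.5) stretches that to [-1, 1]; a pixel value of 128 lands at roughly 0.004.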


def shark_sd_fn_dict_input(
    sd_kwargs: dict,
):
    print("[LOG] Submitting Request...")
    input_imgs = []
    img_paths = sd_kwargs["sd_init_image"]

    for img_path in img_paths:
        if img_path:
            if os.path.isfile(img_path):
                input_imgs.append(
                    Image.open(img_path, mode="r").convert("RGB")
                )
    sd_kwargs["sd_init_image"] = input_imgs
    # result = shark_sd_fn(**sd_kwargs)
    # for i in range(sd_kwargs["batch_count"]):
    #     yield from result
    # return result
    for key in sd_kwargs:
        if sd_kwargs[key] in ["None", "", None, []]:
            sd_kwargs[key] = None
        if key == "seed":
            sd_kwargs[key] = int(sd_kwargs[key])

    for i in range(1):
        generated_imgs = yield from shark_sd_fn(**sd_kwargs)
        yield generated_imgs
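
Because shark_sd_fn is itself a generator, the yield from above streams its intermediate yields through to the caller and binds its final return value to generated_imgs, which is then yielded once more as the terminal result.
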
@@ -278,7 +511,7 @@ def shark_sd_fn_dict_input(
def shark_sd_fn(
    prompt,
    negative_prompt,
    sd_init_image,
    sd_init_image: list,
    height: int,
    width: int,
    steps: int,
@@ -291,6 +524,7 @@ def shark_sd_fn(
    base_model_id: str,
    custom_weights: str,
    custom_vae: str,
    use_base_vae: bool,
    precision: str,
    device: str,
    ondemand: bool,
@@ -300,20 +534,9 @@ def shark_sd_fn(
    embeddings: dict,
):
    sd_kwargs = locals()
    if isinstance(sd_init_image, Image.Image):
        image = sd_init_image.convert("RGB")
    elif sd_init_image:
        image = sd_init_image["image"].convert("RGB")
    else:
        image = None
        is_img2img = False
    if image:
        (
            image,
            _,
            _,
        ) = resize_stencil(image, width, height)
        is_img2img = True
    is_img2img = True if sd_init_image[0] is not None else False


    print("\n[LOG] Performing Stable Diffusion Pipeline setup...")

    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
@@ -358,7 +581,6 @@ def shark_sd_fn(
        "custom_vae": custom_vae,
        "num_loras": num_loras,
        "import_ir": cmd_opts.import_mlir,
        "is_img2img": is_img2img,
        "is_controlled": is_controlled,
    }
    submit_prep_kwargs = {
@@ -366,16 +588,19 @@ def shark_sd_fn(
        "custom_weights": custom_weights,
        "adapters": adapters,
        "embeddings": embeddings,
        "is_img2img": is_img2img,
    }
    submit_run_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "image": sd_init_image,
        "steps": steps,
        "strength": strength,
        "guidance_scale": guidance_scale,
        "seed": seed,
        "ondemand": ondemand,
        "repeatable_seeds": repeatable_seeds,
        "use_base_vae": use_base_vae,
        "resample_type": resample_type,
        "control_mode": control_mode,
        "hints": hints,
@@ -410,13 +635,9 @@ def shark_sd_fn(
        # if global_obj.get_sd_status() == SD_STATE_CANCEL:
        #     break
        # else:
        try:
            this_seed = seed[current_batch]
        except:
            this_seed = seed[0]
        save_output_img(
            out_imgs[0],
            this_seed,
            seed,
            sd_kwargs,
        )
        generated_imgs.extend(out_imgs)
@@ -438,6 +659,7 @@ def view_json_file(file_path):
    return content


if __name__ == "__main__":
    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
    import apps.shark_studio.web.utils.globals as global_obj

@@ -75,9 +75,9 @@ def save_output_img(output_img, img_seed, extra_info=None):
            "parameters",
            f"{extra_info['prompt'][0]}"
            f"\nNegative prompt: {extra_info['negative_prompt'][0]}"
            f"\nSteps: {extra_info['steps'][0]},"
            f"Sampler: {extra_info['scheduler'][0]}, "
            f"CFG scale: {extra_info['guidance_scale'][0]}, "
            f"\nSteps: {extra_info['steps']},"
            f"Sampler: {extra_info['scheduler']}, "
            f"CFG scale: {extra_info['guidance_scale']}, "
            f"Seed: {img_seed},"
            f"Size: {png_size_text}, "
            f"Model: {img_model}, "

@@ -1,5 +1,10 @@
from msvcrt import kbhit
from shark.iree_utils.compile_utils import get_iree_compiled_module, load_vmfb_using_mmap
from shark.iree_utils.compile_utils import (
    get_iree_compiled_module,
    load_vmfb_using_mmap,
    clean_device_info,
    get_iree_target_triple,
)
from apps.shark_studio.web.utils.file_utils import (
    get_checkpoints_path,
    get_resource_path,
@@ -32,8 +37,8 @@ class SharkPipelineBase:
        self.model_map = model_map
        self.static_kwargs = static_kwargs
        self.base_model_id = base_model_id
        self.device_name = device
        self.device = device.split("=>")[-1].strip(" ")
        self.triple = get_iree_target_triple(device)
        self.device, self.device_id = clean_device_info(device)
        self.import_mlir = import_mlir
        self.iree_module_dict = {}
        self.tempfiles = {}
@@ -46,11 +51,11 @@ class SharkPipelineBase:
        # initialization. As soon as you have a pipeline ID unique to your static torch IR parameters,
        # and your model map is populated with any IR - unique model IDs and their static params,
        # call this method to get the artifacts associated with your map.
        self.pipe_id = pipe_id
        self.pipe_id = self.safe_name(pipe_id)
        self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(".."), self.pipe_id))
        self.pipe_vmfb_path.mkdir(parents=True, exist_ok=True)
        print("\n[LOG] Checking for pre-compiled artifacts.")
        if submodel == "None":
            print("\n[LOG] Gathering any pre-compiled artifacts....")
            for key in self.model_map:
                self.get_compiled_map(pipe_id, submodel=key)
        else:
@@ -58,10 +63,12 @@ class SharkPipelineBase:
            ireec_flags = []
            if submodel in self.iree_module_dict:
                if "vmfb" in self.iree_module_dict[submodel]:
                    print(f"[LOG] Found executable for {submodel} at {self.iree_module_dict[submodel]['vmfb']}...")
                    print(f"\n[LOG] Executable for {submodel} already loaded...")
                    return
            elif "vmfb_path" in self.model_map[submodel]:
                return
            elif submodel not in self.tempfiles:
                print(f"[LOG] Tempfile for {submodel} not found. Fetching torch IR...")
                print(f"\n[LOG] Tempfile for {submodel} not found. Fetching torch IR...")
                if submodel in self.static_kwargs:
                    init_kwargs = self.static_kwargs[submodel]
                for key in self.static_kwargs["pipe"]:
@@ -90,16 +97,6 @@ class SharkPipelineBase:
        return


    def hijack_weights(self, weights_path, submodel="None"):
        if submodel == "None":
            for i in self.model_map:
                self.hijack_weights(weights_path, i)
        else:
            if submodel in self.iree_module_dict:
                self.model_map[submodel]["external_weights_file"] = weights_path
        return


    def get_precompiled(self, pipe_id, submodel="None"):
        if submodel == "None":
            for model in self.model_map:
@@ -112,33 +109,10 @@ class SharkPipelineBase:
                    break
            for file in vmfbs:
                if submodel in file:
                    print(f"Found existing .vmfb at {file}")
                    self.iree_module_dict[submodel] = {}
                    (
                        self.iree_module_dict[submodel]["vmfb"],
                        self.iree_module_dict[submodel]["config"],
                        self.iree_module_dict[submodel]["temp_file_to_unlink"],
                    ) = load_vmfb_using_mmap(
                        os.path.join(vmfbs_path, file),
                        self.device,
                        device_idx=0,
                        rt_flags=[],
                        external_weight_file=self.model_map[submodel]['external_weight_file'],
                    )
                    self.model_map[submodel]["vmfb_path"] = os.path.join(vmfbs_path, file)
        return


    def safe_dict(self, kwargs: dict):
        flat_args = {}
        for i in kwargs:
            if isinstance(kwargs[i], dict) and "pass_dict" not in kwargs[i]:
                flat_args[i] = [kwargs[i][j] for j in kwargs[i]]
            else:
                flat_args[i] = kwargs[i]

        return flat_args


    def import_torch_ir(self, submodel, kwargs):
        torch_ir = self.model_map[submodel]["initializer"](
            **self.safe_dict(kwargs), compile_to="torch"
@@ -160,18 +134,53 @@ class SharkPipelineBase:
    def load_submodels(self, submodels: list):
        for submodel in submodels:
            if submodel in self.iree_module_dict:
                print(f"\n[LOG] {submodel} is ready for inference.")
            if "vmfb_path" in self.model_map[submodel]:
                print(
                    f"\n[LOG] Loading .vmfb for {submodel} from {self.iree_module_dict[submodel]['vmfb']}"
                    f"\n[LOG] Loading .vmfb for {submodel} from {self.model_map[submodel]['vmfb_path']}"
                )
                self.iree_module_dict[submodel] = {}
                (
                    self.iree_module_dict[submodel]["vmfb"],
                    self.iree_module_dict[submodel]["config"],
                    self.iree_module_dict[submodel]["temp_file_to_unlink"],
                ) = load_vmfb_using_mmap(
                    self.model_map[submodel]["vmfb_path"],
                    self.device,
                    device_idx=0,
                    rt_flags=[],
                    external_weight_file=self.model_map[submodel]['external_weight_file'],
                )
            else:
                self.get_compiled_map(self.pipe_id, submodel)
        return


    def unload_submodels(self, submodels: list):
        for submodel in submodels:
            if submodel in self.iree_module_dict:
                del self.iree_module_dict[submodel]
                gc.collect()
        return


    def run(self, submodel, inputs):
        inp = [ireert.asdevicearray(self.iree_module_dict[submodel]["config"].device, inputs)]
        if not isinstance(inputs, list):
            inputs = [inputs]
        inp = [ireert.asdevicearray(self.iree_module_dict[submodel]["config"].device, input) for input in inputs]
        return self.iree_module_dict[submodel]['vmfb']['main'](*inp)
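
run() now wraps each input in an IREE device array before invoking the module's main entry point; the superseded one-liner above it passed the whole inputs object as a single array, which broke multi-input submodels. An illustrative call, assuming a compiled "unet" submodel is loaded (mirroring produce_img_latents above):

noise_pred = pipe.run(
    "unet",
    [latent_model_input, timestep, text_embeddings_numpy, guidance_scale],
)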


def safe_name(name):
    return name.replace("/", "_").replace("-", "_")
    def safe_name(self, name):
        return name.replace("/", "_").replace("-", "_").replace("\\", "_")
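
For example, safe_name("runwayml/stable-diffusion-v1-5") yields "runwayml_stable_diffusion_v1_5"; the added backslash rule keeps Windows-style path separators out of artifact names as well.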


    def safe_dict(self, kwargs: dict):
        flat_args = {}
        for i in kwargs:
            if isinstance(kwargs[i], dict) and "pass_dict" not in kwargs[i]:
                flat_args[i] = [kwargs[i][j] for j in kwargs[i]]
            else:
                flat_args[i] = kwargs[i]

        return flat_args

@@ -3,6 +3,7 @@ from typing import List, Optional, Union
from iree import runtime as ireert
import re
import torch
import numpy as np

re_attention = re.compile(
    r"""
@@ -161,7 +162,7 @@ def pad_tokens_and_weights(
    r"""
    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
    """
    max_embeddings_multiples = 8
    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
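
Deriving the multiple keeps the two lengths in sync: with the encoder's chunk_length of 77 and the max_length of (77 - 2) * 8 + 2 = 602 computed in get_weighted_text_embeddings below, (602 - 2) // (77 - 2) = 8, matching the previously hard-coded value.
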
    weights_length = (
        max_length
        if no_boseos_middle
@@ -194,13 +195,16 @@ def pad_tokens_and_weights(

    return tokens, weights


def get_unweighted_text_embeddings(
    pipe,
    text_input: torch.Tensor,
    text_input,
    chunk_length: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    When the length of tokens is a multiple of the capacity of the text encoder,
    it should be split into chunks and sent to the text encoder individually.
    """
    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
    if max_embeddings_multiples > 1:
        text_embeddings = []
@@ -214,7 +218,7 @@ def get_unweighted_text_embeddings(
        text_input_chunk[:, 0] = text_input[0, 0]
        text_input_chunk[:, -1] = text_input[0, -1]

        text_embedding = pipe.run("clip", text_input_chunk)[0]
        text_embedding = pipe.run("clip", text_input_chunk)[0].to_host()

        if no_boseos_middle:
            if i == 0:
@@ -231,50 +235,14 @@ def get_unweighted_text_embeddings(
        # SHARK: Convert the result to tensor
        # text_embeddings = torch.concat(text_embeddings, axis=1)
        text_embeddings_np = np.concatenate(np.array(text_embeddings))
        text_embeddings = torch.from_numpy(text_embeddings_np)[None, :]
        text_embeddings = torch.from_numpy(text_embeddings_np)
    else:
        text_embeddings = pipe.run("clip", text_input)[0]
        # text_embeddings = torch.from_numpy(text_embeddings)[None, :]
    return torch.from_numpy(text_embeddings.to_host())
    """
    When the length of tokens is a multiple of the capacity of the text encoder,
    it should be split into chunks and sent to the text encoder individually.
    """
    max_embeddings_multiples = 8
    text_embeddings = []
    for i in range(max_embeddings_multiples):
        # extract the i-th chunk
        text_input_chunk = text_input[
            :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
        ].clone()

        # cover the head and the tail by the starting and the ending tokens
        text_input_chunk[:, 0] = text_input[0, 0]
        text_input_chunk[:, -1] = text_input[0, -1]
        # text_embedding = pipe.text_encoder(text_input_chunk)[0]

        print(text_input_chunk)
        breakpoint()
        text_embedding = pipe.run("clip", text_input_chunk)
        if no_boseos_middle:
            if i == 0:
                # discard the ending token
                text_embedding = text_embedding[:, :-1]
            elif i == max_embeddings_multiples - 1:
                # discard the starting token
                text_embedding = text_embedding[:, 1:]
            else:
                # discard both starting and ending tokens
                text_embedding = text_embedding[:, 1:-1]

        text_embeddings.append(text_embedding)
    # SHARK: Convert the result to tensor
    # text_embeddings = torch.concat(text_embeddings, axis=1)
    text_embeddings_np = np.concatenate(np.array(text_embeddings))
    text_embeddings = torch.from_numpy(text_embeddings_np)[None, :]
    text_embeddings = torch.from_numpy(text_embeddings.to_host())
    return text_embeddings


# This function deals with NoneType values occurring in tokens after padding.
# It switches None out for 49407, as truncating None values causes matrix dimension errors.
def filter_nonetype_tokens(tokens: List[List]):
@@ -286,7 +254,7 @@ def get_weighted_text_embeddings(
    prompt: List[str],
    uncond_prompt: List[str] = None,
    max_embeddings_multiples: Optional[int] = 8,
    no_boseos_middle: Optional[bool] = False,
    no_boseos_middle: Optional[bool] = True,
    skip_parsing: Optional[bool] = False,
    skip_weighting: Optional[bool] = False,
):
@@ -325,12 +293,12 @@ def get_weighted_text_embeddings(
        max_length = max(
            max_length, max([len(token) for token in uncond_tokens])
        )

    max_embeddings_multiples = min(
        max_embeddings_multiples,
        (max_length - 1) // (pipe.model_max_length - 2) + 1,
    )
    max_embeddings_multiples = max(1, max_embeddings_multiples)

    max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2

    # pad the length of tokens and weights

@@ -18,7 +18,7 @@ from diffusers import (

def get_schedulers(model_id):
    #TODO: switch over to turbine and run all on GPU
    print(f"[LOG] Initializing schedulers from model id: {model_id}")
    print(f"\n[LOG] Initializing schedulers from model id: {model_id}")
    schedulers = dict()
    schedulers["PNDM"] = PNDMScheduler.from_pretrained(
        model_id,

@@ -1,23 +1,24 @@
{
    "prompt": [ "a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))" ],
    "negative_prompt": [ "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped" ],
    "sd_init_image": [ "None" ],
    "sd_init_image": [ null ],
    "height": 512,
    "width": 512,
    "steps": [ 50 ],
    "strength": [ 0.8 ],
    "guidance_scale": [ 7.5 ],
    "seed": [ -1 ],
    "steps": 50,
    "strength": 0.8,
    "guidance_scale": 7.5,
    "seed": -1,
    "batch_count": 1,
    "batch_size": 1,
    "scheduler": [ "EulerDiscrete" ],
    "scheduler": "EulerDiscrete",
    "base_model_id": "runwayml/stable-diffusion-v1-5",
    "custom_weights": "",
    "custom_vae": "",
    "custom_weights": null,
    "custom_vae": null,
    "use_base_vae": false,
    "precision": "fp16",
    "device": "vulkan",
    "ondemand": "False",
    "repeatable_seeds": "False",
    "ondemand": false,
    "repeatable_seeds": false,
    "resample_type": "Nearest Neighbor",
    "controlnets": {},
    "embeddings": {}

@@ -41,6 +41,14 @@ from apps.shark_studio.web.ui.common_events import lora_changed
from apps.shark_studio.modules import logger
import apps.shark_studio.web.utils.globals as global_obj

sd_default_models = [
    "CompVis/stable-diffusion-v1-4",
    "runwayml/stable-diffusion-v1-5",
    "stabilityai/stable-diffusion-2-1-base",
    "stabilityai/stable-diffusion-2-1",
    "stabilityai/stable-diffusion-xl-1.0",
    "stabilityai/sdxl-turbo",
]

def view_json_file(file_path):
    content = ""
@@ -105,6 +113,7 @@ def pull_sd_configs(
    base_model_id,
    custom_weights,
    custom_vae,
    use_base_vae,
    precision,
    device,
    ondemand,
@@ -120,11 +129,6 @@ def pull_sd_configs(
        "prompt",
        "negative_prompt",
        "sd_init_image",
        "steps",
        "strength",
        "guidance_scale",
        "seed",
        "scheduler",
    ]:
        sd_cfg[arg] = [sd_args[arg]]
    elif arg in ["controlnets", "embeddings"]:
@@ -144,8 +148,10 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
            sd_json[key] = new_sd_config[key]
    else:
        sd_json = new_sd_config
    if os.path.isfile(sd_json["sd_init_image"][0]):
        sd_image = Image.open(sd_json["sd_init_image"][0], mode="r")
    for i in sd_json["sd_init_image"]:
        if i is not None:
            if os.path.isfile(i):
                sd_image = [Image.open(i, mode="r")]
    else:
        sd_image = None

@@ -155,16 +161,17 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
        sd_image,
        sd_json["height"],
        sd_json["width"],
        sd_json["steps"][0],
        sd_json["strength"][0],
        sd_json["steps"],
        sd_json["strength"],
        sd_json["guidance_scale"],
        sd_json["seed"][0],
        sd_json["seed"],
        sd_json["batch_count"],
        sd_json["batch_size"],
        sd_json["scheduler"][0],
        sd_json["scheduler"],
        sd_json["base_model_id"],
        sd_json["custom_weights"],
        sd_json["custom_vae"],
        sd_json["use_base_vae"],
        sd_json["precision"],
        sd_json["device"],
        sd_json["ondemand"],
@@ -320,7 +327,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
                    info="Select or enter HF model ID",
                    elem_id="custom_model",
                    value="stabilityai/stable-diffusion-2-1-base",
                    choices=sd_model_map.keys(),
                    choices=sd_default_models,
                )  # base_model_id
                custom_weights = gr.Dropdown(
                    label="Custom Weights",
@@ -328,7 +335,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
                    elem_id="custom_model",
                    value="None",
                    allow_custom_value=True,
                    choices=get_checkpoints(base_model_id),
                    choices=["None"] + get_checkpoints(base_model_id),
                )  #
            with gr.Column(scale=2):
                sd_vae_info = (
@@ -361,6 +368,11 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
                    ],
                    visible=True,
                )
                use_base_vae = gr.Checkbox(
                    value=False,
                    label="Baked VAE",
                    interactive=True,
                )

        with gr.Group(elem_id="prompt_box_outer"):
            prompt = gr.Textbox(
@@ -639,7 +651,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
        with gr.Column(scale=3, min_width=600):
            with gr.Group():
                sd_gallery = gr.Gallery(
                    label="Generated images",
                    label="Generated images",
                    show_label=False,
                    elem_id="gallery",
                    columns=2,
@@ -719,6 +731,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
        base_model_id,
        custom_weights,
        custom_vae,
        use_base_vae,
        precision,
        device,
        ondemand,
@@ -753,6 +766,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
        base_model_id,
        custom_weights,
        custom_vae,
        use_base_vae,
        precision,
        device,
        ondemand,

@@ -65,6 +65,14 @@ def get_iree_device_args(device, extra_args=[]):
        return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
    return []

def get_iree_target_triple(device):
    args = get_iree_device_args(device)
    for flag in args:
        if "triple" in flag.split("-"):
            triple = flag.split("=")
            return triple
    return ""


def clean_device_info(raw_device):
    # return appropriate device and device_id for consumption by Studio pipeline