Fixes to defaults, file mgmt, VAE

This commit is contained in:
Ean Garvey
2023-12-19 00:36:58 -06:00
parent 3dc9ab3857
commit a06adc4eb2
5 changed files with 55 additions and 48 deletions

View File

@@ -40,11 +40,21 @@ sd_model_map = {
"clip": {
"initializer": clip.export_clip_model,
"external_weight_file": None,
"ireec_flags": ["--iree-flow-collapse-reduction-dims"],
"ireec_flags": [
"--iree-flow-collapse-reduction-dims",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
],
},
"vae_encode": {
"initializer": vae.export_vae_model,
"external_weight_file": None,
"ireec_flags": [
"--iree-flow-collapse-reduction-dims",
"--iree-opt-const-expr-hoisting=False",
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
],
},
"unet": {
"initializer": unet.export_unet_model,
@@ -52,6 +62,7 @@ sd_model_map = {
"--iree-flow-collapse-reduction-dims",
"--iree-opt-const-expr-hoisting=False",
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))",
],
"external_weight_file": None,
},
@@ -62,6 +73,8 @@ sd_model_map = {
"--iree-flow-collapse-reduction-dims",
"--iree-opt-const-expr-hoisting=False",
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
],
},
}
@@ -73,12 +86,6 @@ class StableDiffusion(SharkPipelineBase):
# aims to be as general as possible, and the class will infer and compile
# a list of necessary modules or a combined "pipeline module" for a
# specified job based on the inference task.
#
# custom_model_ids: a dict of submodel + HF ID pairs for custom submodels.
# e.g. {"vae_decode": "madebyollin/sdxl-vae-fp16-fix"}
#
# embeddings: a dict of embedding checkpoints or model IDs to use when
# initializing the compiled modules.
def __init__(
self,
@@ -101,7 +108,10 @@ class StableDiffusion(SharkPipelineBase):
self.width = width
self.scheduler_obj = {}
static_kwargs = {
"pipe": {},
"pipe": {
"external_weight_path": get_checkpoints_path(),
"external_weights": "safetensors",
},
"clip": {"hf_model_name": base_model_id},
"unet": {
"hf_model_name": base_model_id,
@@ -114,12 +124,14 @@ class StableDiffusion(SharkPipelineBase):
"height": height,
"width": width,
"precision": precision,
"max_length": 77,
"max_length": self.model_max_length,
},
"vae_encode": {
"hf_model_name": custom_vae if custom_vae else base_model_id,
"hf_model_name": base_model_id,
"vae_model": vae.VaeModel(
hf_model_name=base_model_id, hf_auth_token=None
hf_model_name=base_model_id,
base_vae=False,
custom_vae=custom_vae,
),
"batch_size": batch_size,
"height": height,
@@ -127,9 +139,11 @@ class StableDiffusion(SharkPipelineBase):
"precision": precision,
},
"vae_decode": {
"hf_model_name": custom_vae if custom_vae else base_model_id,
"hf_model_name": base_model_id,
"vae_model": vae.VaeModel(
hf_model_name=base_model_id, hf_auth_token=None
hf_model_name=base_model_id,
base_vae=False,
custom_vae=custom_vae,
),
"batch_size": batch_size,
"height": height,
@@ -163,7 +177,9 @@ class StableDiffusion(SharkPipelineBase):
)
self.is_img2img = is_img2img
schedulers = get_schedulers(self.base_model_id)
self.weights_path = get_checkpoints_path(self.safe_name(self.pipe_id))
self.weights_path = get_checkpoints_path(
os.path.join("..", self.safe_name(self.pipe_id))
)
if not os.path.exists(self.weights_path):
os.mkdir(self.weights_path)
self.scheduler = schedulers[scheduler]
@@ -197,7 +213,6 @@ class StableDiffusion(SharkPipelineBase):
seed,
ondemand,
repeatable_seeds,
use_base_vae,
resample_type,
control_mode,
hints,
@@ -262,7 +277,7 @@ class StableDiffusion(SharkPipelineBase):
for i in tqdm(range(0, latents.shape[0], self.batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + self.batch_size],
use_base_vae=use_base_vae,
use_base_vae=False,
cpu_scheduling=True,
)
all_imgs.extend(imgs)
@@ -295,7 +310,10 @@ class StableDiffusion(SharkPipelineBase):
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
pad = (0, 0) * (len(text_embeddings.shape) - 2)
pad = pad + (0, self.static_kwargs["unet"]["max_length"] - text_embeddings.shape[1])
pad = pad + (
0,
self.static_kwargs["unet"]["max_length"] - text_embeddings.shape[1],
)
text_embeddings = torch.nn.functional.pad(text_embeddings, pad)
# SHARK: Report clip inference time
@@ -378,9 +396,7 @@ class StableDiffusion(SharkPipelineBase):
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
torch.from_numpy(np.asarray(latent_model_input)).to(
self.dtype
),
torch.from_numpy(np.asarray(latent_model_input)).to(self.dtype),
mask,
masked_image_latents,
],
@@ -430,7 +446,7 @@ class StableDiffusion(SharkPipelineBase):
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def decode_latents(self, latents, use_base_vae, cpu_scheduling):
def decode_latents(self, latents, use_base_vae=False, cpu_scheduling=True):
if use_base_vae:
latents = 1 / 0.18215 * latents
@@ -444,14 +460,13 @@ class StableDiffusion(SharkPipelineBase):
vae_inf_time = (time.time() - vae_start) * 1000
# end_profiling(profile_device)
print(f"\n[LOG] VAE Inference time (ms): {vae_inf_time:.3f}")
if use_base_vae:
images = torch.from_numpy(images)
images = (images.detach().cpu() * 255.0).numpy()
images = images.round()
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
pil_images = [Image.fromarray(image) for image in images.numpy()]
pil_images = [Image.fromarray(image).convert("RGB") for image in images.numpy()]
return pil_images
def process_sd_init_image(self, sd_init_image, resample_type):
@@ -501,15 +516,16 @@ def shark_sd_fn_dict_input(
print("[LOG] Submitting Request...")
for key in sd_kwargs:
if sd_kwargs[key] in ["None", "", None, []]:
if sd_kwargs[key] in [None, []]:
sd_kwargs[key] = None
if sd_kwargs[key] in ["None"]:
sd_kwargs[key] = ""
if key == "seed":
sd_kwargs[key] = int(sd_kwargs[key])
for i in range(1):
generated_imgs = yield from shark_sd_fn(**sd_kwargs)
yield generated_imgs
return generated_imgs
def shark_sd_fn(
@@ -528,7 +544,6 @@ def shark_sd_fn(
base_model_id: str,
custom_weights: str,
custom_vae: str,
use_base_vae: bool,
precision: str,
device: str,
ondemand: bool,
@@ -603,7 +618,6 @@ def shark_sd_fn(
"seed": seed,
"ondemand": ondemand,
"repeatable_seeds": repeatable_seeds,
"use_base_vae": use_base_vae,
"resample_type": resample_type,
"control_mode": control_mode,
"hints": hints,

View File

@@ -41,6 +41,9 @@ class SharkPipelineBase:
self.device, self.device_id = clean_device_info(device)
self.import_mlir = import_mlir
self.iree_module_dict = {}
self.tmp_dir = get_resource_path(os.path.join("..", "shark_tmp"))
if not os.path.exists(self.tmp_dir):
os.mkdir(self.tmp_dir)
self.tempfiles = {}
def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
@@ -54,7 +57,7 @@ class SharkPipelineBase:
self.pipe_vmfb_path = Path(
os.path.join(get_checkpoints_path(".."), self.pipe_id)
)
self.pipe_vmfb_path.mkdir(parents=True, exist_ok=True)
self.pipe_vmfb_path.mkdir(parents=False, exist_ok=True)
if submodel == "None":
print("\n[LOG] Gathering any pre-compiled artifacts....")
for key in self.model_map:
@@ -63,9 +66,7 @@ class SharkPipelineBase:
self.get_precompiled(pipe_id, submodel)
ireec_flags = []
if submodel in self.iree_module_dict:
if "vmfb" in self.iree_module_dict[submodel]:
print(f"\n[LOG] Executable for {submodel} already loaded...")
return
return
elif "vmfb_path" in self.model_map[submodel]:
return
elif submodel not in self.tempfiles:
@@ -123,8 +124,9 @@ class SharkPipelineBase:
if submodel == "clip":
# clip.export_clip_model returns (torch_ir, tokenizer)
torch_ir = torch_ir[0]
self.tempfiles[submodel] = get_resource_path(
os.path.join("..", "shark_tmp", f"{submodel}.torch.tempfile")
self.tempfiles[submodel] = os.path.join(
self.tmp_dir, f"{submodel}.torch.tempfile"
)
with open(self.tempfiles[submodel], "w+") as f:

View File

@@ -10,11 +10,10 @@
"seed": -1,
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerDiscrete",
"scheduler": "DDIM",
"base_model_id": "runwayml/stable-diffusion-v1-5",
"custom_weights": null,
"custom_vae": null,
"use_base_vae": true,
"custom_weights": "",
"custom_vae": "",
"precision": "fp16",
"device": "vulkan",
"ondemand": false,

View File

@@ -114,7 +114,6 @@ def pull_sd_configs(
base_model_id,
custom_weights,
custom_vae,
use_base_vae,
precision,
device,
ondemand,
@@ -172,7 +171,6 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
sd_json["base_model_id"],
sd_json["custom_weights"],
sd_json["custom_vae"],
sd_json["use_base_vae"],
sd_json["precision"],
sd_json["device"],
sd_json["ondemand"],
@@ -316,7 +314,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
show_download_button=False,
)
with gr.Column(elem_id="ui_body"):
with gr.Row(variant="compact"):
with gr.Row():
with gr.Column(scale=2, min_width=600):
with gr.Row(equal_height=True):
with gr.Column(scale=3):
@@ -369,12 +367,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
],
visible=True,
)
use_base_vae = gr.Checkbox(
value=False,
label="Baked VAE",
interactive=True,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
@@ -726,7 +718,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
base_model_id,
custom_weights,
custom_vae,
use_base_vae,
precision,
device,
ondemand,
@@ -761,7 +752,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
base_model_id,
custom_weights,
custom_vae,
use_base_vae,
precision,
device,
ondemand,

View File

@@ -44,6 +44,7 @@ def set_sd_status(value):
# Setter for the module-level `_pipe_kwargs` global: writes `value` to stdout,
# then rebinds `_pipe_kwargs` to it. Returns nothing.
# NOTE(review): the print looks like leftover debug output; confirm it is
# intended, since it dumps the whole value on every call.
def set_pipe_kwargs(value):
global _pipe_kwargs
print(value)
_pipe_kwargs = value
@@ -74,6 +75,7 @@ def get_sd_status():
# Getter for the module-level `_pipe_kwargs` global: writes the current value
# to stdout, then returns it unchanged.
# NOTE(review): the print looks like leftover debug output; confirm it is
# intended, since it fires on every read.
def get_pipe_kwargs():
global _pipe_kwargs
print(_pipe_kwargs)
return _pipe_kwargs