mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Fixes to defaults, file mgmt, VAE
This commit is contained in:
@@ -40,11 +40,21 @@ sd_model_map = {
|
||||
"clip": {
|
||||
"initializer": clip.export_clip_model,
|
||||
"external_weight_file": None,
|
||||
"ireec_flags": ["--iree-flow-collapse-reduction-dims"],
|
||||
"ireec_flags": [
|
||||
"--iree-flow-collapse-reduction-dims",
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
|
||||
],
|
||||
},
|
||||
"vae_encode": {
|
||||
"initializer": vae.export_vae_model,
|
||||
"external_weight_file": None,
|
||||
"ireec_flags": [
|
||||
"--iree-flow-collapse-reduction-dims",
|
||||
"--iree-opt-const-expr-hoisting=False",
|
||||
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))",
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
|
||||
],
|
||||
},
|
||||
"unet": {
|
||||
"initializer": unet.export_unet_model,
|
||||
@@ -52,6 +62,7 @@ sd_model_map = {
|
||||
"--iree-flow-collapse-reduction-dims",
|
||||
"--iree-opt-const-expr-hoisting=False",
|
||||
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))",
|
||||
],
|
||||
"external_weight_file": None,
|
||||
},
|
||||
@@ -62,6 +73,8 @@ sd_model_map = {
|
||||
"--iree-flow-collapse-reduction-dims",
|
||||
"--iree-opt-const-expr-hoisting=False",
|
||||
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))",
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
|
||||
],
|
||||
},
|
||||
}
|
||||
@@ -73,12 +86,6 @@ class StableDiffusion(SharkPipelineBase):
|
||||
# aims to be as general as possible, and the class will infer and compile
|
||||
# a list of necessary modules or a combined "pipeline module" for a
|
||||
# specified job based on the inference task.
|
||||
#
|
||||
# custom_model_ids: a dict of submodel + HF ID pairs for custom submodels.
|
||||
# e.g. {"vae_decode": "madebyollin/sdxl-vae-fp16-fix"}
|
||||
#
|
||||
# embeddings: a dict of embedding checkpoints or model IDs to use when
|
||||
# initializing the compiled modules.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -101,7 +108,10 @@ class StableDiffusion(SharkPipelineBase):
|
||||
self.width = width
|
||||
self.scheduler_obj = {}
|
||||
static_kwargs = {
|
||||
"pipe": {},
|
||||
"pipe": {
|
||||
"external_weight_path": get_checkpoints_path(),
|
||||
"external_weights": "safetensors",
|
||||
},
|
||||
"clip": {"hf_model_name": base_model_id},
|
||||
"unet": {
|
||||
"hf_model_name": base_model_id,
|
||||
@@ -114,12 +124,14 @@ class StableDiffusion(SharkPipelineBase):
|
||||
"height": height,
|
||||
"width": width,
|
||||
"precision": precision,
|
||||
"max_length": 77,
|
||||
"max_length": self.model_max_length,
|
||||
},
|
||||
"vae_encode": {
|
||||
"hf_model_name": custom_vae if custom_vae else base_model_id,
|
||||
"hf_model_name": base_model_id,
|
||||
"vae_model": vae.VaeModel(
|
||||
hf_model_name=base_model_id, hf_auth_token=None
|
||||
hf_model_name=base_model_id,
|
||||
base_vae=False,
|
||||
custom_vae=custom_vae,
|
||||
),
|
||||
"batch_size": batch_size,
|
||||
"height": height,
|
||||
@@ -127,9 +139,11 @@ class StableDiffusion(SharkPipelineBase):
|
||||
"precision": precision,
|
||||
},
|
||||
"vae_decode": {
|
||||
"hf_model_name": custom_vae if custom_vae else base_model_id,
|
||||
"hf_model_name": base_model_id,
|
||||
"vae_model": vae.VaeModel(
|
||||
hf_model_name=base_model_id, hf_auth_token=None
|
||||
hf_model_name=base_model_id,
|
||||
base_vae=False,
|
||||
custom_vae=custom_vae,
|
||||
),
|
||||
"batch_size": batch_size,
|
||||
"height": height,
|
||||
@@ -163,7 +177,9 @@ class StableDiffusion(SharkPipelineBase):
|
||||
)
|
||||
self.is_img2img = is_img2img
|
||||
schedulers = get_schedulers(self.base_model_id)
|
||||
self.weights_path = get_checkpoints_path(self.safe_name(self.pipe_id))
|
||||
self.weights_path = get_checkpoints_path(
|
||||
os.path.join("..", self.safe_name(self.pipe_id))
|
||||
)
|
||||
if not os.path.exists(self.weights_path):
|
||||
os.mkdir(self.weights_path)
|
||||
self.scheduler = schedulers[scheduler]
|
||||
@@ -197,7 +213,6 @@ class StableDiffusion(SharkPipelineBase):
|
||||
seed,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
use_base_vae,
|
||||
resample_type,
|
||||
control_mode,
|
||||
hints,
|
||||
@@ -262,7 +277,7 @@ class StableDiffusion(SharkPipelineBase):
|
||||
for i in tqdm(range(0, latents.shape[0], self.batch_size)):
|
||||
imgs = self.decode_latents(
|
||||
latents=latents[i : i + self.batch_size],
|
||||
use_base_vae=use_base_vae,
|
||||
use_base_vae=False,
|
||||
cpu_scheduling=True,
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
@@ -295,7 +310,10 @@ class StableDiffusion(SharkPipelineBase):
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
pad = (0, 0) * (len(text_embeddings.shape) - 2)
|
||||
pad = pad + (0, self.static_kwargs["unet"]["max_length"] - text_embeddings.shape[1])
|
||||
pad = pad + (
|
||||
0,
|
||||
self.static_kwargs["unet"]["max_length"] - text_embeddings.shape[1],
|
||||
)
|
||||
text_embeddings = torch.nn.functional.pad(text_embeddings, pad)
|
||||
|
||||
# SHARK: Report clip inference time
|
||||
@@ -378,9 +396,7 @@ class StableDiffusion(SharkPipelineBase):
|
||||
if mask is not None and masked_image_latents is not None:
|
||||
latent_model_input = torch.cat(
|
||||
[
|
||||
torch.from_numpy(np.asarray(latent_model_input)).to(
|
||||
self.dtype
|
||||
),
|
||||
torch.from_numpy(np.asarray(latent_model_input)).to(self.dtype),
|
||||
mask,
|
||||
masked_image_latents,
|
||||
],
|
||||
@@ -430,7 +446,7 @@ class StableDiffusion(SharkPipelineBase):
|
||||
all_latents = torch.cat(latent_history, dim=0)
|
||||
return all_latents
|
||||
|
||||
def decode_latents(self, latents, use_base_vae, cpu_scheduling):
|
||||
def decode_latents(self, latents, use_base_vae=False, cpu_scheduling=True):
|
||||
if use_base_vae:
|
||||
latents = 1 / 0.18215 * latents
|
||||
|
||||
@@ -444,14 +460,13 @@ class StableDiffusion(SharkPipelineBase):
|
||||
vae_inf_time = (time.time() - vae_start) * 1000
|
||||
# end_profiling(profile_device)
|
||||
print(f"\n[LOG] VAE Inference time (ms): {vae_inf_time:.3f}")
|
||||
|
||||
if use_base_vae:
|
||||
images = torch.from_numpy(images)
|
||||
images = (images.detach().cpu() * 255.0).numpy()
|
||||
images = images.round()
|
||||
|
||||
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
|
||||
pil_images = [Image.fromarray(image) for image in images.numpy()]
|
||||
pil_images = [Image.fromarray(image).convert("RGB") for image in images.numpy()]
|
||||
return pil_images
|
||||
|
||||
def process_sd_init_image(self, sd_init_image, resample_type):
|
||||
@@ -501,15 +516,16 @@ def shark_sd_fn_dict_input(
|
||||
print("[LOG] Submitting Request...")
|
||||
|
||||
for key in sd_kwargs:
|
||||
if sd_kwargs[key] in ["None", "", None, []]:
|
||||
if sd_kwargs[key] in [None, []]:
|
||||
sd_kwargs[key] = None
|
||||
if sd_kwargs[key] in ["None"]:
|
||||
sd_kwargs[key] = ""
|
||||
if key == "seed":
|
||||
sd_kwargs[key] = int(sd_kwargs[key])
|
||||
|
||||
for i in range(1):
|
||||
generated_imgs = yield from shark_sd_fn(**sd_kwargs)
|
||||
yield generated_imgs
|
||||
return generated_imgs
|
||||
|
||||
|
||||
def shark_sd_fn(
|
||||
@@ -528,7 +544,6 @@ def shark_sd_fn(
|
||||
base_model_id: str,
|
||||
custom_weights: str,
|
||||
custom_vae: str,
|
||||
use_base_vae: bool,
|
||||
precision: str,
|
||||
device: str,
|
||||
ondemand: bool,
|
||||
@@ -603,7 +618,6 @@ def shark_sd_fn(
|
||||
"seed": seed,
|
||||
"ondemand": ondemand,
|
||||
"repeatable_seeds": repeatable_seeds,
|
||||
"use_base_vae": use_base_vae,
|
||||
"resample_type": resample_type,
|
||||
"control_mode": control_mode,
|
||||
"hints": hints,
|
||||
|
||||
@@ -41,6 +41,9 @@ class SharkPipelineBase:
|
||||
self.device, self.device_id = clean_device_info(device)
|
||||
self.import_mlir = import_mlir
|
||||
self.iree_module_dict = {}
|
||||
self.tmp_dir = get_resource_path(os.path.join("..", "shark_tmp"))
|
||||
if not os.path.exists(self.tmp_dir):
|
||||
os.mkdir(self.tmp_dir)
|
||||
self.tempfiles = {}
|
||||
|
||||
def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
|
||||
@@ -54,7 +57,7 @@ class SharkPipelineBase:
|
||||
self.pipe_vmfb_path = Path(
|
||||
os.path.join(get_checkpoints_path(".."), self.pipe_id)
|
||||
)
|
||||
self.pipe_vmfb_path.mkdir(parents=True, exist_ok=True)
|
||||
self.pipe_vmfb_path.mkdir(parents=False, exist_ok=True)
|
||||
if submodel == "None":
|
||||
print("\n[LOG] Gathering any pre-compiled artifacts....")
|
||||
for key in self.model_map:
|
||||
@@ -63,9 +66,7 @@ class SharkPipelineBase:
|
||||
self.get_precompiled(pipe_id, submodel)
|
||||
ireec_flags = []
|
||||
if submodel in self.iree_module_dict:
|
||||
if "vmfb" in self.iree_module_dict[submodel]:
|
||||
print(f"\n[LOG] Executable for {submodel} already loaded...")
|
||||
return
|
||||
return
|
||||
elif "vmfb_path" in self.model_map[submodel]:
|
||||
return
|
||||
elif submodel not in self.tempfiles:
|
||||
@@ -123,8 +124,9 @@ class SharkPipelineBase:
|
||||
if submodel == "clip":
|
||||
# clip.export_clip_model returns (torch_ir, tokenizer)
|
||||
torch_ir = torch_ir[0]
|
||||
self.tempfiles[submodel] = get_resource_path(
|
||||
os.path.join("..", "shark_tmp", f"{submodel}.torch.tempfile")
|
||||
|
||||
self.tempfiles[submodel] = os.path.join(
|
||||
self.tmp_dir, f"{submodel}.torch.tempfile"
|
||||
)
|
||||
|
||||
with open(self.tempfiles[submodel], "w+") as f:
|
||||
|
||||
@@ -10,11 +10,10 @@
|
||||
"seed": -1,
|
||||
"batch_count": 1,
|
||||
"batch_size": 1,
|
||||
"scheduler": "EulerDiscrete",
|
||||
"scheduler": "DDIM",
|
||||
"base_model_id": "runwayml/stable-diffusion-v1-5",
|
||||
"custom_weights": null,
|
||||
"custom_vae": null,
|
||||
"use_base_vae": true,
|
||||
"custom_weights": "",
|
||||
"custom_vae": "",
|
||||
"precision": "fp16",
|
||||
"device": "vulkan",
|
||||
"ondemand": false,
|
||||
|
||||
@@ -114,7 +114,6 @@ def pull_sd_configs(
|
||||
base_model_id,
|
||||
custom_weights,
|
||||
custom_vae,
|
||||
use_base_vae,
|
||||
precision,
|
||||
device,
|
||||
ondemand,
|
||||
@@ -172,7 +171,6 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
|
||||
sd_json["base_model_id"],
|
||||
sd_json["custom_weights"],
|
||||
sd_json["custom_vae"],
|
||||
sd_json["use_base_vae"],
|
||||
sd_json["precision"],
|
||||
sd_json["device"],
|
||||
sd_json["ondemand"],
|
||||
@@ -316,7 +314,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
show_download_button=False,
|
||||
)
|
||||
with gr.Column(elem_id="ui_body"):
|
||||
with gr.Row(variant="compact"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2, min_width=600):
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Column(scale=3):
|
||||
@@ -369,12 +367,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
use_base_vae = gr.Checkbox(
|
||||
value=False,
|
||||
label="Baked VAE",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
@@ -726,7 +718,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
base_model_id,
|
||||
custom_weights,
|
||||
custom_vae,
|
||||
use_base_vae,
|
||||
precision,
|
||||
device,
|
||||
ondemand,
|
||||
@@ -761,7 +752,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
base_model_id,
|
||||
custom_weights,
|
||||
custom_vae,
|
||||
use_base_vae,
|
||||
precision,
|
||||
device,
|
||||
ondemand,
|
||||
|
||||
@@ -44,6 +44,7 @@ def set_sd_status(value):
|
||||
|
||||
def set_pipe_kwargs(value):
|
||||
global _pipe_kwargs
|
||||
print(value)
|
||||
_pipe_kwargs = value
|
||||
|
||||
|
||||
@@ -74,6 +75,7 @@ def get_sd_status():
|
||||
|
||||
def get_pipe_kwargs():
|
||||
global _pipe_kwargs
|
||||
print(_pipe_kwargs)
|
||||
return _pipe_kwargs
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user