Tweaks to VAE, pipeline states

This commit is contained in:
Ean Garvey
2023-12-19 22:37:56 -06:00
parent a06adc4eb2
commit 12884591a5
5 changed files with 148 additions and 214 deletions

View File

@@ -33,21 +33,22 @@ from apps.shark_studio.modules.ckpt_processing import (
process_custom_pipe_weights,
)
from transformers import CLIPTokenizer
from diffusers.image_processor import VaeImageProcessor
from math import ceil
from PIL import Image
sd_model_map = {
"clip": {
"initializer": clip.export_clip_model,
"external_weight_file": None,
"ireec_flags": [
"--iree-flow-collapse-reduction-dims",
"--iree-opt-const-expr-hoisting=False",
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
],
},
"vae_encode": {
"initializer": vae.export_vae_model,
"external_weight_file": None,
"ireec_flags": [
"--iree-flow-collapse-reduction-dims",
"--iree-opt-const-expr-hoisting=False",
@@ -64,11 +65,9 @@ sd_model_map = {
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))",
],
"external_weight_file": None,
},
"vae_decode": {
"initializer": vae.export_vae_model,
"external_weight_file": None,
"ireec_flags": [
"--iree-flow-collapse-reduction-dims",
"--iree-opt-const-expr-hoisting=False",
@@ -109,7 +108,6 @@ class StableDiffusion(SharkPipelineBase):
self.scheduler_obj = {}
static_kwargs = {
"pipe": {
"external_weight_path": get_checkpoints_path(),
"external_weights": "safetensors",
},
"clip": {"hf_model_name": base_model_id},
@@ -130,7 +128,6 @@ class StableDiffusion(SharkPipelineBase):
"hf_model_name": base_model_id,
"vae_model": vae.VaeModel(
hf_model_name=base_model_id,
base_vae=False,
custom_vae=custom_vae,
),
"batch_size": batch_size,
@@ -142,7 +139,6 @@ class StableDiffusion(SharkPipelineBase):
"hf_model_name": base_model_id,
"vae_model": vae.VaeModel(
hf_model_name=base_model_id,
base_vae=False,
custom_vae=custom_vae,
),
"batch_size": batch_size,
@@ -177,115 +173,28 @@ class StableDiffusion(SharkPipelineBase):
)
self.is_img2img = is_img2img
schedulers = get_schedulers(self.base_model_id)
self.weights_path = get_checkpoints_path(
os.path.join("..", self.safe_name(self.pipe_id))
)
self.scheduler = schedulers[scheduler]
self.image_processor = VaeImageProcessor()#do_convert_rgb=True)
self.weights_path = os.path.join(get_checkpoints_path(), self.safe_name(self.base_model_id))
if not os.path.exists(self.weights_path):
os.mkdir(self.weights_path)
self.scheduler = schedulers[scheduler]
print(f"[LOG] Loaded scheduler: {scheduler}")
for model in adapters:
self.model_map[model] = adapters[model]
if custom_weights:
if os.path.isfile(custom_weights):
for i in self.model_map:
self.model_map[i]["external_weights_file"] = None
elif custom_weights:
print(
f"\n[LOG][WARNING] Custom weights were not found at {custom_weights}. Did you mean to pass a base model ID?"
)
self.static_kwargs["pipe"] = {
# "external_weight_path": self.weights_path,
# "external_weights": "safetensors",
}
for submodel in self.static_kwargs:
if custom_weights:
if submodel not in ["clip", "clip2"]:
self.static_kwargs[submodel]["external_weight_file"] = custom_weights
else:
self.static_kwargs[submodel]["external_weight_path"] = os.path.join(self.weights_path, submodel + ".safetensors")
else:
self.static_kwargs[submodel]["external_weight_path"] = os.path.join(self.weights_path, submodel + ".safetensors")
self.get_compiled_map(pipe_id=self.pipe_id)
print("\n[LOG] Pipeline successfully prepared for runtime.")
return
def generate_images(
self,
prompt,
negative_prompt,
image,
steps,
strength,
guidance_scale,
seed,
ondemand,
repeatable_seeds,
resample_type,
control_mode,
hints,
):
"""Run the full text/img-to-image pipeline and return a list of PIL images.

Encodes the prompts, prepares initial latents (optionally from an init
image when ``self.is_img2img``), denoises them, then decodes latents to
images with the VAE submodel in ``self.batch_size`` chunks.
"""
# TODO: Batched args
# ``ondemand`` toggles unloading the VAE submodel after decoding (below).
self.ondemand = ondemand
if self.is_img2img:
image, _ = self.process_sd_init_image(image, resample_type)
else:
# txt2img path: no init image is fed to prepare_latents.
image = None
print("\n[LOG] Generating images...")
batched_args = [
prompt,
negative_prompt,
# steps,
# strength,
# guidance_scale,
# seed,
# resample_type,
# control_mode,
# hints,
]
# NOTE(review): rebinding ``arg`` inside the loop never writes back into
# ``batched_args`` or the original variables — this normalization appears
# to be a no-op as written; confirm intended behavior.
for arg in batched_args:
if not isinstance(arg, list):
arg = [arg] * self.batch_size
if len(arg) < self.batch_size:
arg = arg * self.batch_size
else:
arg = [arg[i] for i in range(self.batch_size)]
text_embeddings = self.encode_prompts_weight(
prompt,
negative_prompt,
)
# Clamp/replace the seed so it fits in an unsigned 32-bit value; out-of-range
# (including the -1 "random" sentinel) draws a fresh random seed.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
init_latents, final_timesteps = self.prepare_latents(
generator=generator,
num_inference_steps=steps,
image=image,
strength=strength,
)
latents = self.produce_img_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=final_timesteps,
cpu_scheduling=True,  # until we have schedulers through Turbine
)
# Img latents -> PIL images
all_imgs = []
self.load_submodels(["vae_decode"])
# Decode in batch_size chunks to bound peak memory during VAE inference.
for i in tqdm(range(0, latents.shape[0], self.batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + self.batch_size],
use_base_vae=False,
cpu_scheduling=True,
)
all_imgs.extend(imgs)
if self.ondemand:
# Free the VAE submodel when running in on-demand (low-memory) mode.
self.unload_submodels(["vae_decode"])
return all_imgs
def encode_prompts_weight(
self,
prompt,
@@ -446,10 +355,7 @@ class StableDiffusion(SharkPipelineBase):
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def decode_latents(self, latents, use_base_vae=False, cpu_scheduling=True):
if use_base_vae:
latents = 1 / 0.18215 * latents
def decode_latents(self, latents, cpu_scheduling=True):
latents_numpy = latents.to(self.dtype)
if cpu_scheduling:
latents_numpy = latents.detach().numpy()
@@ -460,13 +366,9 @@ class StableDiffusion(SharkPipelineBase):
vae_inf_time = (time.time() - vae_start) * 1000
# end_profiling(profile_device)
print(f"\n[LOG] VAE Inference time (ms): {vae_inf_time:.3f}")
if use_base_vae:
images = torch.from_numpy(images)
images = (images.detach().cpu() * 255.0).numpy()
images = images.round()
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
pil_images = [Image.fromarray(image).convert("RGB") for image in images.numpy()]
images = torch.from_numpy(images).permute(0, 2, 3, 1).float().numpy()
pil_images = self.image_processor.numpy_to_pil(images)
return pil_images
def process_sd_init_image(self, sd_init_image, resample_type):
@@ -508,6 +410,83 @@ class StableDiffusion(SharkPipelineBase):
is_img2img = True
image = image_arr
return image, is_img2img
def generate_images(
self,
prompt,
negative_prompt,
image,
steps,
strength,
guidance_scale,
seed,
ondemand,
repeatable_seeds,
resample_type,
control_mode,
hints,
):
"""Run the full text/img-to-image pipeline and return a list of PIL images.

Encodes the prompts, prepares initial latents (optionally from an init
image when ``self.is_img2img``), denoises them, then decodes latents to
images with the VAE submodel in ``self.batch_size`` chunks.
"""
# TODO: Batched args
# ``ondemand`` toggles unloading the VAE submodel after decoding (below).
self.ondemand = ondemand
if self.is_img2img:
image, _ = self.process_sd_init_image(image, resample_type)
else:
# txt2img path: no init image is fed to prepare_latents.
image = None
print("\n[LOG] Generating images...")
batched_args = [
prompt,
negative_prompt,
image,
]
# NOTE(review): rebinding ``arg`` inside the loop never writes back into
# ``batched_args`` or the original variables — this normalization appears
# to be a no-op as written; confirm intended behavior.
for arg in batched_args:
if not isinstance(arg, list):
arg = [arg] * self.batch_size
if len(arg) < self.batch_size:
arg = arg * self.batch_size
else:
arg = [arg[i] for i in range(self.batch_size)]
text_embeddings = self.encode_prompts_weight(
prompt,
negative_prompt,
)
# Clamp/replace the seed so it fits in an unsigned 32-bit value; out-of-range
# (including the -1 "random" sentinel) draws a fresh random seed.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
init_latents, final_timesteps = self.prepare_latents(
generator=generator,
num_inference_steps=steps,
image=image,
strength=strength,
)
latents = self.produce_img_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=final_timesteps,
cpu_scheduling=True,  # until we have schedulers through Turbine
)
# Img latents -> PIL images
all_imgs = []
self.load_submodels(["vae_decode"])
# Decode in batch_size chunks to bound peak memory during VAE inference.
for i in tqdm(range(0, latents.shape[0], self.batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + self.batch_size],
cpu_scheduling=True,
)
all_imgs.extend(imgs)
if self.ondemand:
# Free the VAE submodel when running in on-demand (low-memory) mode.
self.unload_submodels(["vae_decode"])
return all_imgs
def shark_sd_fn_dict_input(
@@ -622,7 +601,6 @@ def shark_sd_fn(
"control_mode": control_mode,
"hints": hints,
}
print(submit_pipe_kwargs)
if (
not global_obj.get_sd_obj()
or global_obj.get_pipe_kwargs() != submit_pipe_kwargs
@@ -630,7 +608,6 @@ def shark_sd_fn(
print("\n[LOG] Initializing new pipeline...")
global_obj.clear_cache()
gc.collect()
global_obj.set_pipe_kwargs(submit_pipe_kwargs)
# Initializes the pipeline and retrieves IR based on all
# parameters that are static in the turbine output format,
@@ -640,8 +617,14 @@ def shark_sd_fn(
**submit_pipe_kwargs,
)
global_obj.set_sd_obj(sd_pipe)
global_obj.set_pipe_kwargs(submit_pipe_kwargs)
if (
not global_obj.get_prep_kwargs()
or global_obj.get_prep_kwargs() != submit_prep_kwargs
):
global_obj.set_prep_kwargs(submit_prep_kwargs)
global_obj.get_sd_obj().prepare_pipe(**submit_prep_kwargs)
global_obj.get_sd_obj().prepare_pipe(**submit_prep_kwargs)
generated_imgs = []
for current_batch in range(batch_count):
start_time = time.time()
@@ -653,7 +636,7 @@ def shark_sd_fn(
# break
# else:
save_output_img(
out_imgs[0],
out_imgs[current_batch],
seed,
sd_kwargs,
)

View File

@@ -35,6 +35,7 @@ class SharkPipelineBase:
import_mlir: bool = True,
):
self.model_map = model_map
self.pipe_map = {}
self.static_kwargs = static_kwargs
self.base_model_id = base_model_id
self.triple = get_iree_target_triple(device)
@@ -45,6 +46,7 @@ class SharkPipelineBase:
if not os.path.exists(self.tmp_dir):
os.mkdir(self.tmp_dir)
self.tempfiles = {}
self.pipe_vmfb_path = ""
def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
# First checks whether we have .vmfbs precompiled, then populates the map
@@ -63,11 +65,12 @@ class SharkPipelineBase:
for key in self.model_map:
self.get_compiled_map(pipe_id, submodel=key)
else:
self.get_precompiled(pipe_id, submodel)
self.pipe_map[submodel] = {}
self.get_precompiled(self.pipe_id, submodel)
ireec_flags = []
if submodel in self.iree_module_dict:
return
elif "vmfb_path" in self.model_map[submodel]:
elif "vmfb_path" in self.pipe_map[submodel]:
return
elif submodel not in self.tempfiles:
print(
@@ -87,10 +90,8 @@ class SharkPipelineBase:
else []
)
if "external_weights_file" in self.model_map[submodel]:
weights_path = self.model_map[submodel]["external_weights_file"]
else:
weights_path = None
weights_path = self.get_io_params(submodel)
self.iree_module_dict[submodel] = get_iree_compiled_module(
self.tempfiles[submodel],
device=self.device,
@@ -101,20 +102,30 @@ class SharkPipelineBase:
write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb"),
)
return
def get_io_params(self, submodel):
    """Resolve the external-weights file to use for *submodel*.

    Precedence: a user-supplied ``external_weight_file`` wins over the
    default ``external_weight_path`` exported for the HF model; if neither
    key is present, ``None`` is returned and the torch IR is assumed to
    carry its own weights.
    """
    io_kwargs = self.static_kwargs[submodel]
    if "external_weight_file" in io_kwargs:
        # Custom weights were provided by the user.
        return io_kwargs["external_weight_file"]
    if "external_weight_path" in io_kwargs:
        # Default weights exported for the HF model.
        return io_kwargs["external_weight_path"]
    # No external weights configured; the torch IR embeds them.
    return None
def get_precompiled(self, pipe_id, submodel="None"):
if submodel == "None":
for model in self.model_map:
self.get_precompiled(pipe_id, model)
vmfbs = []
vmfb_matches = {}
vmfbs_path = self.pipe_vmfb_path
for dirpath, dirnames, filenames in os.walk(vmfbs_path):
for dirpath, dirnames, filenames in os.walk(self.pipe_vmfb_path):
vmfbs.extend(filenames)
break
for file in vmfbs:
if submodel in file:
self.model_map[submodel]["vmfb_path"] = os.path.join(vmfbs_path, file)
self.pipe_map[submodel]["vmfb_path"] = os.path.join(self.pipe_vmfb_path, file)
return
def import_torch_ir(self, submodel, kwargs):
@@ -139,9 +150,11 @@ class SharkPipelineBase:
for submodel in submodels:
if submodel in self.iree_module_dict:
print(f"\n[LOG] {submodel} is ready for inference.")
if "vmfb_path" in self.model_map[submodel]:
continue
if "vmfb_path" in self.pipe_map[submodel]:
weights_path = self.get_io_params(submodel)
print(
f"\n[LOG] Loading .vmfb for {submodel} from {self.model_map[submodel]['vmfb_path']}"
f"\n[LOG] Loading .vmfb for {submodel} from {self.pipe_map[submodel]['vmfb_path']}"
)
self.iree_module_dict[submodel] = {}
(
@@ -149,13 +162,11 @@ class SharkPipelineBase:
self.iree_module_dict[submodel]["config"],
self.iree_module_dict[submodel]["temp_file_to_unlink"],
) = load_vmfb_using_mmap(
self.model_map[submodel]["vmfb_path"],
self.pipe_map[submodel]["vmfb_path"],
self.device,
device_idx=0,
rt_flags=[],
external_weight_file=self.model_map[submodel][
"external_weight_file"
],
external_weight_file=weights_path,
)
else:
self.get_compiled_map(self.pipe_id, submodel)

View File

@@ -10,7 +10,7 @@
"seed": -1,
"batch_count": 1,
"batch_size": 1,
"scheduler": "DDIM",
"scheduler": "EulerDiscrete",
"base_model_id": "runwayml/stable-diffusion-v1-5",
"custom_weights": "",
"custom_vae": "",

View File

@@ -226,76 +226,6 @@ def import_original(original_img, width, height):
return EditorValue(img_dict)
# def update_cn_input(
# model,
# stencil,
# preprocessed_hint,
# ):
# print("update_cn_input")
# if model == None:
# stencil = None
# preprocessed_hint = None
# return [
# gr.update(),
# gr.update(),
# gr.update(),
# gr.update(),
# gr.update(),
# gr.update(),
# stencil,
# preprocessed_hint,
# ]
# elif model == "scribble":
# return [
# gr.ImageEditor(
# visible=True,
# interactive=True,
# show_label=False,
# image_mode="RGB",
# type="pil",
# brush=Brush(
# colors=["#000000"],
# color_mode="fixed",
# default_size=5,
# ),
# ),
# gr.Image(
# visible=True,
# show_label=False,
# interactive=True,
# show_download_button=False,
# ),
# gr.Slider(visible=True, label="Canvas Width"),
# gr.Slider(visible=True, label="Canvas Height"),
# gr.Button(visible=True),
# gr.Button(visible=False),
# stencil,
# preprocessed_hint,
# ]
# else:
# return [
# gr.ImageEditor(
# visible=True,
# interactive=True,
# show_label=False,
# image_mode="RGB",
# type="pil",
# ),
# gr.Image(
# visible=True,
# show_label=False,
# interactive=True,
# show_download_button=False,
# ),
# gr.Slider(visible=True, label="Canvas Width"),
# gr.Slider(visible=True, label="Canvas Height"),
# gr.Button(visible=False),
# gr.Button(visible=True),
# stencil,
# preprocessed_hint,
# ]
with gr.Blocks(title="Stable Diffusion") as sd_element:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)

View File

@@ -12,11 +12,13 @@ def _init():
global _sd_obj
global _devices
global _pipe_kwargs
global _prep_kwargs
global _gen_kwargs
global _schedulers
_sd_obj = None
_devices = None
_pipe_kwargs = None
_prep_kwargs = None
_gen_kwargs = None
_schedulers = None
set_devices()
@@ -44,15 +46,19 @@ def set_sd_status(value):
def set_pipe_kwargs(value):
global _pipe_kwargs
print(value)
_pipe_kwargs = value
def set_prep_kwargs(value):
    """Record the pipeline-preparation kwargs in module-level state."""
    global _prep_kwargs
    _prep_kwargs = value
def set_gen_kwargs(value):
    """Record the image-generation kwargs in module-level state."""
    global _gen_kwargs
    _gen_kwargs = value
def set_schedulers(value):
    """Record the scheduler registry in module-level state."""
    global _schedulers
    _schedulers = value
@@ -75,10 +81,14 @@ def get_sd_status():
def get_pipe_kwargs():
global _pipe_kwargs
print(_pipe_kwargs)
return _pipe_kwargs
def get_prep_kwargs():
    """Return the most recently stored pipeline-preparation kwargs."""
    global _prep_kwargs
    return _prep_kwargs
def get_gen_kwargs():
    """Return the most recently stored image-generation kwargs."""
    global _gen_kwargs
    return _gen_kwargs
@@ -92,14 +102,14 @@ def get_scheduler(key):
def clear_cache():
    """Drop all cached pipeline state and force a garbage-collection pass.

    Every module-level cache slot is deleted first — so ``gc.collect()`` can
    reclaim the (potentially large) pipeline object — and then re-created as
    ``None`` so subsequent getters see an initialized-but-empty state.
    """
    global _sd_obj
    global _pipe_kwargs
    global _prep_kwargs
    global _gen_kwargs
    global _schedulers
    del _sd_obj
    del _pipe_kwargs
    # Fix: _prep_kwargs was reset below but never deleted like its siblings,
    # leaving the old prep kwargs alive through the gc.collect() call.
    del _prep_kwargs
    del _gen_kwargs
    del _schedulers
    gc.collect()
    _sd_obj = None
    _pipe_kwargs = None
    _prep_kwargs = None
    _gen_kwargs = None
    _schedulers = None