(SDXL) Fix --ondemand and vae scale factor use, and fix VAE flags.

Ean Garvey
2023-12-01 00:38:31 -06:00
parent fdd0c52575
commit 1df48421d9
4 changed files with 18 additions and 12 deletions

View File

@@ -165,13 +165,8 @@ class SharkifyStableDiffusionModel:
         self.check_params(max_len, width, height)
         self.max_len = max_len
         self.is_sdxl = is_sdxl
-        self.height = height
-        self.width = width
-        if is_sdxl:
-            # We need to scale down the height/width by vae_scale_factor, which
-            # happens to be 8 in this case.
-            self.height = height // 8
-            self.width = width // 8
+        self.height = height // 8
+        self.width = width // 8
         self.batch_size = batch_size
         self.custom_weights = custom_weights.strip()
         self.use_quantize = use_quantize
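
Note: the hunk above drops the SDXL-only branch so the latent height and width are always computed by dividing by the VAE scale factor (8 for the standard SD VAE, per the removed comment). A minimal sketch of that arithmetic, with latent_dims as a hypothetical helper that is not part of this commit:

    # Hypothetical helper, only to illustrate the math in this hunk.
    def latent_dims(height: int, width: int, vae_scale_factor: int = 8) -> tuple[int, int]:
        # The VAE downsamples each spatial dimension by vae_scale_factor,
        # so a 1024x1024 SDXL image maps to 128x128 latents.
        return height // vae_scale_factor, width // vae_scale_factor

    assert latent_dims(1024, 1024) == (128, 128)
    assert latent_dims(512, 512) == (64, 64)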

View File

@@ -33,6 +33,7 @@ from apps.stable_diffusion.src.utils import (
     end_profiling,
 )
 import sys
+import gc
 from typing import List, Optional

 SD_STATE_IDLE = "idle"
@@ -189,6 +190,7 @@ class StableDiffusionPipeline:
     def unload_vae(self):
         del self.vae
         self.vae = None
+        gc.collect()

     def encode_prompt_sdxl(
         self,
@@ -327,6 +329,7 @@ class StableDiffusionPipeline:

         if self.ondemand:
             self.unload_clip_sdxl()
+            gc.collect()

         # TODO: Look into dtype for text_encoder_2!
         prompt_embeds = prompt_embeds.to(dtype=torch.float32)
@@ -387,6 +390,7 @@ class StableDiffusionPipeline:
         clip_inf_time = (time.time() - clip_inf_start) * 1000
         if self.ondemand:
             self.unload_clip()
+            gc.collect()

         self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
         return text_embeddings
@@ -499,6 +503,8 @@ class StableDiffusionPipeline:

         if self.ondemand:
             self.unload_unet()
+            self.unload_unet_512()
+            gc.collect()

         avg_step_time = step_time_sum / len(total_timesteps)
         self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -556,6 +562,8 @@ class StableDiffusionPipeline:
                    break

         if self.ondemand:
             self.unload_unet()
+            gc.collect()
+
         avg_step_time = step_time_sum / len(total_timesteps)
         self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -652,7 +660,6 @@ class StableDiffusionPipeline:
             use_lora,
             ondemand,
         )
-
         return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)

     # #####################################################
@@ -765,6 +772,7 @@ class StableDiffusionPipeline:
         clip_inf_time = (time.time() - clip_inf_start) * 1000
         if self.ondemand:
             self.unload_clip()
+            gc.collect()

         self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
         return text_embeddings.numpy()
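
Note: each --ondemand branch above now follows the same unload pattern: delete the reference to the submodel, null the attribute, and call gc.collect() so memory is actually reclaimed before the next stage loads. A minimal sketch of that pattern in isolation, with OnDemandPipeline as a hypothetical stand-in for the pipeline class above:

    import gc

    class OnDemandPipeline:
        def __init__(self, vae):
            self.vae = vae

        def unload_vae(self):
            # del drops this reference; setting the attribute to None keeps
            # the object usable afterwards, and gc.collect() forces an
            # immediate collection so the freed memory is returned now
            # rather than whenever the collector next runs.
            del self.vae
            self.vae = None
            gc.collect()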

View File

@@ -565,9 +565,10 @@ def get_opt_flags(model, precision="fp16"):
         iree_flags += opt_flags[model][is_tuned][precision][
             "specified_compilation_flags"
         ][device]

-    # Due to lack of support for multi-reduce, we always collapse reduction
-    # dims before dispatch formation right now.
-    iree_flags += ["--iree-flow-collapse-reduction-dims"]
+    if "vae" not in model:
+        # Due to lack of support for multi-reduce, we always collapse reduction
+        # dims before dispatch formation right now.
+        iree_flags += ["--iree-flow-collapse-reduction-dims"]

     return iree_flags
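
Note: after this change, --iree-flow-collapse-reduction-dims is appended for every submodel except the VAE (the "fix VAE flags" part of the commit title; the diff does not say why the VAE is exempted). A sketch of the same gating on its own, with collapse_flags as a hypothetical function and illustrative model names:

    def collapse_flags(model: str) -> list[str]:
        # Collapse reduction dims for every submodel except the VAE,
        # which this commit exempts from the flag.
        if "vae" not in model:
            return ["--iree-flow-collapse-reduction-dims"]
        return []

    assert collapse_flags("unet") == ["--iree-flow-collapse-reduction-dims"]
    assert collapse_flags("vae_encode") == []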

View File

@@ -147,6 +147,8 @@ def txt2img_sdxl_inf(
     # For SDXL we set max_length as 77.
     print("Setting max_length = 77")
     max_length = 77
+    if global_obj.get_cfg_obj().ondemand:
+        print("Running txt2img in memory efficient mode.")
     txt2img_sdxl_obj = Text2ImageSDXLPipeline.from_pretrained(
         scheduler=scheduler_obj,
         import_mlir=args.import_mlir,
@@ -164,7 +166,7 @@ def txt2img_sdxl_inf(
         debug=args.import_debug if args.import_mlir else False,
         use_lora=args.use_lora,
         use_quantize=args.use_quantize,
-        ondemand=args.ondemand,
+        ondemand=global_obj.get_cfg_obj().ondemand,
     )

     global_obj.set_sd_obj(txt2img_sdxl_obj)
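
Note: the last hunk swaps args.ondemand for global_obj.get_cfg_obj().ondemand. args is parsed once at startup, while the config object tracks the app's current state, so only the latter reflects a toggle flipped after launch. A hedged sketch of how the two can diverge, with Cfg and GlobalObj as simplified stand-ins for the app's globals:

    class Cfg:
        def __init__(self, ondemand: bool):
            self.ondemand = ondemand

    class GlobalObj:
        def __init__(self, cfg: Cfg):
            self._cfg = cfg

        def get_cfg_obj(self) -> Cfg:
            return self._cfg

    args_ondemand = False                        # frozen at argument-parsing time
    global_obj = GlobalObj(Cfg(ondemand=True))   # updated later from the UI
    assert global_obj.get_cfg_obj().ondemand != args_ondemand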