mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-09 13:57:54 -05:00
(SDXL) Fix --ondemand and vae scale factor use, and fix VAE flags.
This commit is contained in:
@@ -165,13 +165,8 @@ class SharkifyStableDiffusionModel:
|
||||
self.check_params(max_len, width, height)
|
||||
self.max_len = max_len
|
||||
self.is_sdxl = is_sdxl
|
||||
self.height = height
|
||||
self.width = width
|
||||
if is_sdxl:
|
||||
# We need to scale down the height/width by vae_scale_factor, which
|
||||
# happens to be 8 in this case.
|
||||
self.height = height // 8
|
||||
self.width = width // 8
|
||||
self.height = height // 8
|
||||
self.width = width // 8
|
||||
self.batch_size = batch_size
|
||||
self.custom_weights = custom_weights.strip()
|
||||
self.use_quantize = use_quantize
|
||||
|
||||
@@ -33,6 +33,7 @@ from apps.stable_diffusion.src.utils import (
|
||||
end_profiling,
|
||||
)
|
||||
import sys
|
||||
import gc
|
||||
from typing import List, Optional
|
||||
|
||||
SD_STATE_IDLE = "idle"
|
||||
@@ -189,6 +190,7 @@ class StableDiffusionPipeline:
|
||||
def unload_vae(self):
|
||||
del self.vae
|
||||
self.vae = None
|
||||
gc.collect()
|
||||
|
||||
def encode_prompt_sdxl(
|
||||
self,
|
||||
@@ -327,6 +329,7 @@ class StableDiffusionPipeline:
|
||||
|
||||
if self.ondemand:
|
||||
self.unload_clip_sdxl()
|
||||
gc.collect()
|
||||
|
||||
# TODO: Look into dtype for text_encoder_2!
|
||||
prompt_embeds = prompt_embeds.to(dtype=torch.float32)
|
||||
@@ -387,6 +390,7 @@ class StableDiffusionPipeline:
|
||||
clip_inf_time = (time.time() - clip_inf_start) * 1000
|
||||
if self.ondemand:
|
||||
self.unload_clip()
|
||||
gc.collect()
|
||||
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
|
||||
|
||||
return text_embeddings
|
||||
@@ -499,6 +503,8 @@ class StableDiffusionPipeline:
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
self.unload_unet_512()
|
||||
gc.collect()
|
||||
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
@@ -556,6 +562,8 @@ class StableDiffusionPipeline:
|
||||
break
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
gc.collect()
|
||||
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
@@ -652,7 +660,6 @@ class StableDiffusionPipeline:
|
||||
use_lora,
|
||||
ondemand,
|
||||
)
|
||||
|
||||
return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
|
||||
# #####################################################
|
||||
@@ -765,6 +772,7 @@ class StableDiffusionPipeline:
|
||||
clip_inf_time = (time.time() - clip_inf_start) * 1000
|
||||
if self.ondemand:
|
||||
self.unload_clip()
|
||||
gc.collect()
|
||||
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
|
||||
|
||||
return text_embeddings.numpy()
|
||||
|
||||
@@ -565,9 +565,10 @@ def get_opt_flags(model, precision="fp16"):
|
||||
iree_flags += opt_flags[model][is_tuned][precision][
|
||||
"specified_compilation_flags"
|
||||
][device]
|
||||
# Due to lack of support for multi-reduce, we always collapse reduction
|
||||
# dims before dispatch formation right now.
|
||||
iree_flags += ["--iree-flow-collapse-reduction-dims"]
|
||||
if "vae" not in model:
|
||||
# Due to lack of support for multi-reduce, we always collapse reduction
|
||||
# dims before dispatch formation right now.
|
||||
iree_flags += ["--iree-flow-collapse-reduction-dims"]
|
||||
return iree_flags
|
||||
|
||||
|
||||
|
||||
@@ -147,6 +147,8 @@ def txt2img_sdxl_inf(
|
||||
# For SDXL we set max_length as 77.
|
||||
print("Setting max_length = 77")
|
||||
max_length = 77
|
||||
if global_obj.get_cfg_obj().ondemand:
|
||||
print("Running txt2img in memory efficient mode.")
|
||||
txt2img_sdxl_obj = Text2ImageSDXLPipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
@@ -164,7 +166,7 @@ def txt2img_sdxl_inf(
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
use_quantize=args.use_quantize,
|
||||
ondemand=args.ondemand,
|
||||
ondemand=global_obj.get_cfg_obj().ondemand,
|
||||
)
|
||||
global_obj.set_sd_obj(txt2img_sdxl_obj)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user