Enable tuner for upscaler unet. (#1563)

This commit is contained in:
Ean Garvey
2023-06-20 13:40:13 -05:00
committed by GitHub
parent 0def74f520
commit ccf944c1bd

View File

@@ -17,6 +17,10 @@ from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
def load_mlir_module():
if "upscaler" in args.hf_model_id:
is_upscaler = True
else:
is_upscaler = False
sd_model = SharkifyStableDiffusionModel(
args.hf_model_id,
args.ckpt_loc,
@@ -27,6 +31,7 @@ def load_mlir_module():
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
is_upscaler=is_upscaler,
use_tuned=False,
low_cpu_mem_usage=args.low_cpu_mem_usage,
return_mlir=True,