mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-25 03:00:12 -04:00
[SDXL] Add SDXL pipeline to SHARK (#1731)
-- This commit adds SDXL pipeline to SHARK. Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
import torch
|
||||
import transformers
|
||||
import time
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Text2ImagePipeline,
|
||||
Text2ImageSDXLPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
@@ -16,31 +16,62 @@ def main():
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
# TODO: prompt_embeds and text_embeds form base_model.json requires fixing
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
seed = args.seed
|
||||
txt2img_obj = Text2ImagePipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
precision=args.precision,
|
||||
max_length=args.max_length,
|
||||
batch_size=args.batch_size,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
use_quantize=args.use_quantize,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
if args.height == 1024:
|
||||
assert (
|
||||
args.width == 1024
|
||||
), "currently we support only 1024x1024 image size via SDXL"
|
||||
assert args.precision == "fp16", "currently we support fp16 for SDXL"
|
||||
# For SDXL we set max_length as 77.
|
||||
args.max_length = 77
|
||||
txt2img_obj = Text2ImageSDXLPipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
precision=args.precision,
|
||||
max_length=args.max_length,
|
||||
batch_size=args.batch_size,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
use_quantize=args.use_quantize,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
args.height <= 768 and args.width <= 768
|
||||
), "height/width not in supported range"
|
||||
txt2img_obj = Text2ImagePipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
precision=args.precision,
|
||||
max_length=args.max_length,
|
||||
batch_size=args.batch_size,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
use_quantize=args.use_quantize,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
|
||||
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
|
||||
for current_batch in range(args.batch_count):
|
||||
|
||||
Reference in New Issue
Block a user