mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-02-19 11:56:43 -05:00
Sd3 initial support via turbine pipeline
This commit is contained in:
@@ -39,17 +39,25 @@ EMPTY_SD_MAP = {
|
||||
|
||||
EMPTY_SDXL_MAP = {
|
||||
"prompt_encoder": None,
|
||||
"scheduled_unet": None,
|
||||
"unet": None,
|
||||
"vae_decode": None,
|
||||
"pipeline": None,
|
||||
"full_pipeline": None,
|
||||
"scheduler": None,
|
||||
}
|
||||
|
||||
EMPTY_SD3_MAP = {
|
||||
"clip": None,
|
||||
"mmdit": None,
|
||||
"vae": None,
|
||||
"scheduler": None,
|
||||
}
|
||||
|
||||
EMPTY_FLAGS = {
|
||||
"clip": None,
|
||||
"unet": None,
|
||||
"mmdit": None,
|
||||
"vae": None,
|
||||
"pipeline": None,
|
||||
"scheduler": None,
|
||||
}
|
||||
|
||||
|
||||
@@ -86,21 +94,50 @@ class StableDiffusion:
|
||||
scheduler: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
clip_device: str = None,
|
||||
vae_device: str = None,
|
||||
target_triple: str = None,
|
||||
custom_vae: str = None,
|
||||
num_loras: int = 0,
|
||||
import_ir: bool = True,
|
||||
is_controlled: bool = False,
|
||||
external_weights: str = "safetensors",
|
||||
vae_precision: str = "fp16",
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
progress(0, desc="Initializing pipeline...")
|
||||
self.ui_device = device
|
||||
backend, target = parse_device(device, target_triple)
|
||||
if clip_device:
|
||||
clip_device, clip_target = parse_device(clip_device)
|
||||
else:
|
||||
clip_device, clip_target = backend, target
|
||||
if vae_device:
|
||||
vae_device, vae_target = parse_device(vae_device)
|
||||
else:
|
||||
vae_device, vae_target = backend, target
|
||||
devices = {
|
||||
"clip": clip_device,
|
||||
"mmdit": backend,
|
||||
"vae": vae_device,
|
||||
}
|
||||
targets = {
|
||||
"clip": clip_target,
|
||||
"mmdit": target,
|
||||
"vae": vae_target,
|
||||
}
|
||||
pipe_device_id = backend
|
||||
target_triple = target
|
||||
for key in devices:
|
||||
if devices[key] != backend:
|
||||
pipe_device_id = "hybrid"
|
||||
target_triple = "_".join([clip_target, target, vae_target])
|
||||
self.precision = precision
|
||||
self.compiled_pipeline = False
|
||||
self.base_model_id = base_model_id
|
||||
self.custom_vae = custom_vae
|
||||
self.is_sdxl = "xl" in self.base_model_id.lower()
|
||||
self.is_sd3 = "stable-diffusion-3" in self.base_model_id.lower()
|
||||
self.is_custom = ".py" in self.base_model_id.lower()
|
||||
if self.is_custom:
|
||||
custom_module = load_script(
|
||||
@@ -115,23 +152,32 @@ class StableDiffusion:
|
||||
SharkSDXLPipeline,
|
||||
)
|
||||
self.turbine_pipe = SharkSDXLPipeline
|
||||
self.dynamic_steps = False
|
||||
self.dynamic_steps = True
|
||||
self.model_map = EMPTY_SDXL_MAP
|
||||
elif self.is_sd3:
|
||||
from turbine_models.custom_models.sd3_inference.sd3_pipeline import SharkSD3Pipeline, empty_pipe_dict
|
||||
|
||||
self.turbine_pipe = SharkSD3Pipeline
|
||||
self.dynamic_steps = True
|
||||
self.model_map = EMPTY_SD3_MAP
|
||||
else:
|
||||
from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline
|
||||
|
||||
self.turbine_pipe = SharkSDPipeline
|
||||
self.dynamic_steps = True
|
||||
self.model_map = EMPTY_SD_MAP
|
||||
# no multi-device yet
|
||||
devices = backend
|
||||
targets = target
|
||||
max_length = 64
|
||||
target_backend, self.rt_device, triple = parse_device(device, target_triple)
|
||||
|
||||
pipe_id_list = [
|
||||
safe_name(base_model_id),
|
||||
str(batch_size),
|
||||
str(max_length),
|
||||
f"{str(height)}x{str(width)}",
|
||||
precision,
|
||||
triple,
|
||||
target_triple,
|
||||
]
|
||||
if num_loras > 0:
|
||||
pipe_id_list.append(str(num_loras) + "lora")
|
||||
@@ -151,16 +197,16 @@ class StableDiffusion:
|
||||
|
||||
decomp_attn = True
|
||||
attn_spec = None
|
||||
if triple in ["gfx940", "gfx942", "gfx90a"]:
|
||||
if target_triple in ["gfx940", "gfx942", "gfx90a"]:
|
||||
decomp_attn = False
|
||||
attn_spec = "mfma"
|
||||
elif triple in ["gfx1100", "gfx1103", "gfx1150"]:
|
||||
elif target in ["gfx1100", "gfx1103", "gfx1150"]:
|
||||
decomp_attn = False
|
||||
attn_spec = "wmma"
|
||||
if triple in ["gfx1103", "gfx1150"]:
|
||||
if target in ["gfx1103", "gfx1150"]:
|
||||
# external weights have issues on igpu
|
||||
external_weights = None
|
||||
elif target_backend == "llvm-cpu":
|
||||
elif backend == "llvm-cpu":
|
||||
decomp_attn = False
|
||||
progress(0.5, desc="Initializing pipeline...")
|
||||
self.sd_pipe = self.turbine_pipe(
|
||||
@@ -172,15 +218,15 @@ class StableDiffusion:
|
||||
max_length=max_length,
|
||||
batch_size=batch_size,
|
||||
num_inference_steps=steps,
|
||||
device=target_backend,
|
||||
iree_target_triple=triple,
|
||||
device=devices,
|
||||
iree_target_triple=targets,
|
||||
ireec_flags=EMPTY_FLAGS,
|
||||
attn_spec=attn_spec,
|
||||
decomp_attn=decomp_attn,
|
||||
pipeline_dir=self.pipeline_dir,
|
||||
external_weights_dir=self.weights_path,
|
||||
external_weights=external_weights,
|
||||
custom_vae=custom_vae,
|
||||
vae_precision=vae_precision,
|
||||
)
|
||||
progress(1, desc="Pipeline initialized!...")
|
||||
gc.collect()
|
||||
@@ -191,15 +237,23 @@ class StableDiffusion:
|
||||
adapters,
|
||||
embeddings,
|
||||
is_img2img,
|
||||
compiled_pipeline,
|
||||
compiled_pipeline = False,
|
||||
cpu_scheduling=False,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
progress(0, desc="Preparing models...")
|
||||
|
||||
pipe_map = copy.deepcopy(self.model_map)
|
||||
if compiled_pipeline and self.is_sdxl:
|
||||
pipe_map.pop("scheduler")
|
||||
pipe_map.pop("unet")
|
||||
pipe_map["scheduled_unet"] = None
|
||||
pipe_map["full_pipeline"] = None
|
||||
if cpu_scheduling:
|
||||
pipe_map.pop("scheduler")
|
||||
self.is_img2img = False
|
||||
mlirs = copy.deepcopy(self.model_map)
|
||||
vmfbs = copy.deepcopy(self.model_map)
|
||||
weights = copy.deepcopy(self.model_map)
|
||||
mlirs = copy.deepcopy(pipe_map)
|
||||
vmfbs = copy.deepcopy(pipe_map)
|
||||
weights = copy.deepcopy(pipe_map)
|
||||
if not self.is_sdxl:
|
||||
compiled_pipeline = False
|
||||
self.compiled_pipeline = compiled_pipeline
|
||||
@@ -260,7 +314,7 @@ class StableDiffusion:
|
||||
progress(0.75, desc=f"Loading models and weights...")
|
||||
|
||||
self.sd_pipe.load_pipeline(
|
||||
vmfbs, weights, self.rt_device, self.compiled_pipeline
|
||||
vmfbs, weights, self.compiled_pipeline
|
||||
)
|
||||
progress(1, desc="Pipeline loaded! Generating images...")
|
||||
return
|
||||
@@ -324,7 +378,7 @@ def shark_sd_fn_dict_input(sd_kwargs: dict, *, progress=gr.Progress()):
|
||||
)
|
||||
return None, ""
|
||||
if sd_kwargs["target_triple"] == "":
|
||||
if not parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2]:
|
||||
if not parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[1]:
|
||||
gr.Warning(
|
||||
"Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
|
||||
)
|
||||
@@ -359,6 +413,9 @@ def shark_sd_fn(
|
||||
controlnets: dict,
|
||||
embeddings: dict,
|
||||
seed_increment: str | int = 1,
|
||||
clip_device: str = None,
|
||||
vae_device: str = None,
|
||||
vae_precision: str = None,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
sd_kwargs = locals()
|
||||
@@ -398,7 +455,8 @@ def shark_sd_fn(
|
||||
control_mode = controlnets["control_mode"]
|
||||
for i in controlnets["hint"]:
|
||||
hints.append[i]
|
||||
|
||||
if not vae_precision:
|
||||
vae_precision = precision
|
||||
submit_pipe_kwargs = {
|
||||
"base_model_id": base_model_id,
|
||||
"height": height,
|
||||
@@ -406,6 +464,8 @@ def shark_sd_fn(
|
||||
"batch_size": batch_size,
|
||||
"precision": precision,
|
||||
"device": device,
|
||||
"clip_device": clip_device,
|
||||
"vae_device": vae_device,
|
||||
"target_triple": target_triple,
|
||||
"custom_vae": custom_vae,
|
||||
"num_loras": num_loras,
|
||||
@@ -413,6 +473,7 @@ def shark_sd_fn(
|
||||
"is_controlled": is_controlled,
|
||||
"steps": steps,
|
||||
"scheduler": scheduler,
|
||||
"vae_precision": vae_precision,
|
||||
}
|
||||
submit_prep_kwargs = {
|
||||
"custom_weights": custom_weights,
|
||||
@@ -485,7 +546,7 @@ def shark_sd_fn(
|
||||
sd_kwargs,
|
||||
)
|
||||
generated_imgs.extend(out_imgs)
|
||||
|
||||
breakpoint()
|
||||
yield generated_imgs, status_label(
|
||||
"Stable Diffusion", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
@@ -12,6 +12,18 @@ from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from cpuinfo import get_cpu_info
|
||||
|
||||
|
||||
_IREE_DEVICE_MAP = {
|
||||
"cpu": "local-task",
|
||||
"cpu-task": "local-task",
|
||||
"cpu-sync": "local-sync",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
"metal": "metal",
|
||||
"rocm": "rocm",
|
||||
"hip": "hip",
|
||||
"intel-gpu": "level_zero",
|
||||
}
|
||||
|
||||
def iree_device_map(device):
|
||||
uri_parts = device.split("://", 2)
|
||||
iree_driver = (
|
||||
@@ -31,25 +43,6 @@ def get_supported_device_list():
|
||||
return list(_IREE_DEVICE_MAP.keys())
|
||||
|
||||
|
||||
_IREE_DEVICE_MAP = {
|
||||
"cpu": "local-task",
|
||||
"cpu-task": "local-task",
|
||||
"cpu-sync": "local-sync",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
"metal": "metal",
|
||||
"rocm": "rocm",
|
||||
"hip": "hip",
|
||||
"intel-gpu": "level_zero",
|
||||
}
|
||||
|
||||
|
||||
def iree_target_map(device):
|
||||
if "://" in device:
|
||||
device = device.split("://")[0]
|
||||
return _IREE_TARGET_MAP[device] if device in _IREE_TARGET_MAP else device
|
||||
|
||||
|
||||
_IREE_TARGET_MAP = {
|
||||
"cpu": "llvm-cpu",
|
||||
"cpu-task": "llvm-cpu",
|
||||
@@ -63,9 +56,13 @@ _IREE_TARGET_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def iree_target_map(device):
|
||||
if "://" in device:
|
||||
device = device.split("://")[0]
|
||||
return _IREE_TARGET_MAP[device] if device in _IREE_TARGET_MAP else device
|
||||
|
||||
|
||||
def get_available_devices():
|
||||
return ['rocm', 'cpu']
|
||||
def get_devices_by_name(driver_name):
|
||||
|
||||
device_list = []
|
||||
@@ -163,7 +160,6 @@ def clean_device_info(raw_device):
|
||||
return device, device_id
|
||||
|
||||
def parse_device(device_str, target_override=""):
|
||||
|
||||
rt_driver, device_id = clean_device_info(device_str)
|
||||
target_backend = iree_target_map(rt_driver)
|
||||
if device_id:
|
||||
@@ -174,19 +170,19 @@ def parse_device(device_str, target_override=""):
|
||||
if target_override:
|
||||
if "cpu" in device_str:
|
||||
rt_device = "local-task"
|
||||
return target_backend, rt_device, target_override
|
||||
return target_backend, target_override
|
||||
match target_backend:
|
||||
case "vulkan-spirv":
|
||||
triple = get_iree_target_triple(device_str)
|
||||
return target_backend, rt_device, triple
|
||||
triple = None #get_iree_target_triple(device_str)
|
||||
return target_backend, triple
|
||||
case "rocm":
|
||||
triple = get_rocm_target_chip(device_str)
|
||||
return target_backend, rt_device, triple
|
||||
return target_backend, triple
|
||||
case "llvm-cpu":
|
||||
if "Ryzen 9" in device_str:
|
||||
return target_backend, "local-task", "znver4"
|
||||
return target_backend, "znver4"
|
||||
else:
|
||||
return "llvm-cpu", "local-task", "x86_64-linux-gnu"
|
||||
return "llvm-cpu", "x86_64-linux-gnu"
|
||||
|
||||
|
||||
def get_rocm_target_chip(device_str):
|
||||
@@ -223,68 +219,4 @@ def get_all_devices(driver_name):
|
||||
device_list_src = driver.query_available_devices()
|
||||
device_list_src.sort(key=lambda d: d["path"])
|
||||
del driver
|
||||
return device_list_src
|
||||
|
||||
|
||||
# def get_device_mapping(driver, key_combination=3):
|
||||
# """This method ensures consistent device ordering when choosing
|
||||
# specific devices for execution
|
||||
# Args:
|
||||
# driver (str): execution driver (vulkan, cuda, rocm, etc)
|
||||
# key_combination (int, optional): choice for mapping value for
|
||||
# device name.
|
||||
# 1 : path
|
||||
# 2 : name
|
||||
# 3 : (name, path)
|
||||
# Defaults to 3.
|
||||
# Returns:
|
||||
# dict: map to possible device names user can input mapped to desired
|
||||
# combination of name/path.
|
||||
# """
|
||||
|
||||
# driver = iree_device_map(driver)
|
||||
# device_list = get_all_devices(driver)
|
||||
# device_map = dict()
|
||||
|
||||
# def get_output_value(dev_dict):
|
||||
# if key_combination == 1:
|
||||
# return f"{driver}://{dev_dict['path']}"
|
||||
# if key_combination == 2:
|
||||
# return dev_dict["name"]
|
||||
# if key_combination == 3:
|
||||
# return dev_dict["name"], f"{driver}://{dev_dict['path']}"
|
||||
|
||||
# # mapping driver name to default device (driver://0)
|
||||
# device_map[f"{driver}"] = get_output_value(device_list[0])
|
||||
# for i, device in enumerate(device_list):
|
||||
# # mapping with index
|
||||
# device_map[f"{driver}://{i}"] = get_output_value(device)
|
||||
# # mapping with full path
|
||||
# device_map[f"{driver}://{device['path']}"] = get_output_value(device)
|
||||
# return device_map
|
||||
|
||||
|
||||
# def get_opt_flags(model, precision="fp16"):
|
||||
# iree_flags = []
|
||||
# if len(cmd_opts.iree_vulkan_target_triple) > 0:
|
||||
# iree_flags.append(
|
||||
# f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}"
|
||||
# )
|
||||
# if "rocm" in cmd_opts.device:
|
||||
# from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
# rocm_args = get_iree_rocm_args()
|
||||
# iree_flags.extend(rocm_args)
|
||||
# if cmd_opts.iree_constant_folding == False:
|
||||
# iree_flags.append("--iree-opt-const-expr-hoisting=False")
|
||||
# iree_flags.append(
|
||||
# "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
|
||||
# )
|
||||
# if cmd_opts.data_tiling == False:
|
||||
# iree_flags.append("--iree-opt-data-tiling=False")
|
||||
|
||||
# if "vae" not in model:
|
||||
# # Due to lack of support for multi-reduce, we always collapse reduction
|
||||
# # dims before dispatch formation right now.
|
||||
# iree_flags += ["--iree-flow-collapse-reduction-dims"]
|
||||
# return iree_flags
|
||||
return device_list_src
|
||||
@@ -89,8 +89,13 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
f"VAE: {img_vae}, "
|
||||
f"LoRA: {img_loras}",
|
||||
)
|
||||
|
||||
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
|
||||
if isinstance(output_img, list):
|
||||
for i, img in enumerate(output_img):
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}_{i}.png")
|
||||
img.save(out_img_path, "PNG", pnginfo=pngInfo)
|
||||
else:
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
|
||||
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
|
||||
|
||||
if cmd_opts.output_img_format not in ["png", "jpg"]:
|
||||
print(
|
||||
|
||||
@@ -52,6 +52,7 @@ sd_default_models = [
|
||||
# "stabilityai/stable-diffusion-2-1",
|
||||
# "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"stabilityai/sdxl-turbo",
|
||||
"stabilityai/stable-diffusion-3-medium-diffusers",
|
||||
]
|
||||
sd_default_models.extend(get_checkpoints(model_type="scripts"))
|
||||
|
||||
@@ -315,7 +316,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=init_config["device"] if init_config["device"] else "rocm",
|
||||
value=init_config["device"] if init_config["device"] else global_obj.get_device_list()[0],
|
||||
choices=global_obj.get_device_list(),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
@@ -347,8 +348,8 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
value=512,
|
||||
step=512,
|
||||
label="\U00002195\U0000FE0F Height",
|
||||
interactive=False, # DEMO
|
||||
visible=False, # DEMO
|
||||
interactive=True, # DEMO
|
||||
visible=True, # DEMO
|
||||
)
|
||||
width = gr.Slider(
|
||||
512,
|
||||
@@ -356,8 +357,8 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
value=512,
|
||||
step=512,
|
||||
label="\U00002194\U0000FE0F Width",
|
||||
interactive=False, # DEMO
|
||||
visible=False, # DEMO
|
||||
interactive=True, # DEMO
|
||||
visible=True, # DEMO
|
||||
)
|
||||
|
||||
with gr.Accordion(
|
||||
|
||||
@@ -6,11 +6,13 @@ setuptools
|
||||
wheel
|
||||
|
||||
|
||||
torch==2.3.0
|
||||
torch>=2.3.0
|
||||
shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
|
||||
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-unify-sd#subdirectory=models
|
||||
diffusers @ git+https://github.com/nod-ai/diffusers@v0.24.0-release
|
||||
diffusers @ git+https://github.com/nod-ai/diffusers@0.29.0.dev0-shark
|
||||
brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
|
||||
peft
|
||||
sentencepiece
|
||||
|
||||
# SHARK Runner
|
||||
tqdm
|
||||
|
||||
@@ -89,7 +89,7 @@ else {python -m venv .\shark.venv\}
|
||||
python -m pip install --upgrade pip
|
||||
pip install wheel
|
||||
pip install --pre -r requirements.txt
|
||||
pip install https://github.com/nod-ai/SRT/releases/download/candidate-20240602.283/iree_compiler-20240602.283-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240602.283/iree_runtime-20240602.283-cp311-cp311-win_amd64.whl
|
||||
pip install https://github.com/nod-ai/SRT/releases/download/candidate-20240617.289/iree_compiler-20240617.289-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240617.289/iree_runtime-20240617.289-cp311-cp311-win_amd64.whl
|
||||
pip install -e .
|
||||
|
||||
Write-Host "Source your venv with ./shark.venv/Scripts/activate"
|
||||
|
||||
Reference in New Issue
Block a user