Compare commits

...

2 Commits

Author SHA1 Message Date
Ean Garvey
c7bd54de23 Small fixes to igpu, SDXL-turbo 2024-05-30 00:54:12 -05:00
Ean Garvey
1adfb3b24b Add igpu and custom triple support. 2024-05-29 13:59:13 -05:00
4 changed files with 113 additions and 82 deletions

View File

@@ -73,10 +73,12 @@ class StableDiffusion:
scheduler: str,
precision: str,
device: str,
target_triple: str = None,
custom_vae: str = None,
num_loras: int = 0,
import_ir: bool = True,
is_controlled: bool = False,
external_weights: str = "safetensors",
):
self.precision = precision
self.compiled_pipeline = False
@@ -89,9 +91,8 @@ class StableDiffusion:
else:
self.turbine_pipe = SharkSDPipeline
self.model_map = EMPTY_SD_MAP
external_weights = "safetensors"
max_length = 64
target_backend, self.rt_device, triple = parse_device(device)
target_backend, self.rt_device, triple = parse_device(device, target_triple)
pipe_id_list = [
safe_name(base_model_id),
str(batch_size),
@@ -121,9 +122,12 @@ class StableDiffusion:
if triple in ["gfx940", "gfx942", "gfx90a"]:
decomp_attn = False
attn_spec = "mfma"
elif triple in ["gfx1100", "gfx1103"]:
elif triple in ["gfx1100", "gfx1103", "gfx1150"]:
decomp_attn = False
attn_spec = "wmma"
if triple in ["gfx1103", "gfx1150"]:
# external weights have issues on igpu
external_weights = None
elif target_backend == "llvm-cpu":
decomp_attn = False
@@ -273,6 +277,7 @@ def shark_sd_fn(
custom_vae: str,
precision: str,
device: str,
target_triple: str,
ondemand: bool,
repeatable_seeds: bool,
resample_type: str,
@@ -284,8 +289,6 @@ def shark_sd_fn(
sd_init_image = [sd_init_image]
is_img2img = True if sd_init_image[0] is not None else False
print("\n[LOG] Performing Stable Diffusion Pipeline setup...")
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
import apps.shark_studio.web.utils.globals as global_obj
@@ -326,6 +329,7 @@ def shark_sd_fn(
"batch_size": batch_size,
"precision": precision,
"device": device,
"target_triple": target_triple,
"custom_vae": custom_vae,
"num_loras": num_loras,
"import_ir": import_ir,

View File

@@ -52,6 +52,13 @@ def get_available_devices():
set_iree_runtime_flags()
available_devices = []
rocm_devices = get_devices_by_name("rocm")
available_devices.extend(rocm_devices)
cpu_device = get_devices_by_name("cpu-sync")
available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
available_devices.extend(cpu_device)
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
)
@@ -64,19 +71,27 @@ def get_available_devices():
id += 1
if id != 0:
print(f"vulkan devices are available.")
available_devices.extend(vulkan_devices)
metal_devices = get_devices_by_name("metal")
available_devices.extend(metal_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
rocm_devices = get_devices_by_name("rocm")
available_devices.extend(rocm_devices)
hip_devices = get_devices_by_name("hip")
available_devices.extend(hip_devices)
cpu_device = get_devices_by_name("cpu-sync")
available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
available_devices.extend(cpu_device)
for idx, device_str in enumerate(available_devices):
if "AMD Radeon(TM) Graphics =>" in device_str:
igpu_id_candidates = [
x.split("w/")[-1].split("=>")[0]
for x in available_devices
if "M Graphics" in x
]
for igpu_name in igpu_id_candidates:
if igpu_name:
available_devices[idx] = device_str.replace(
"AMD Radeon(TM) Graphics", igpu_name
)
break
return available_devices
@@ -129,7 +144,7 @@ def set_iree_runtime_flags():
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
def parse_device(device_str):
def parse_device(device_str, target_override=""):
from shark.iree_utils.compile_utils import (
clean_device_info,
get_iree_target_triple,
@@ -143,6 +158,8 @@ def parse_device(device_str):
else:
rt_device = rt_driver
if target_override:
return target_backend, rt_device, target_override
match target_backend:
case "vulkan-spirv":
triple = get_iree_target_triple(device_str)
@@ -168,6 +185,7 @@ def get_rocm_target_chip(device_str):
"MI100": "gfx908",
"MI50": "gfx906",
"MI60": "gfx906",
"780M": "gfx1103",
}
for key in rocm_chip_map:
if key in device_str:

View File

@@ -24,47 +24,47 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers["DDPM"] = DDPMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["DDIM"] = DDIMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["LCMScheduler"] = LCMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["DPMSolverMultistep"] = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", algorithm_type="dpmsolver"
)
schedulers["DPMSolverMultistep++"] = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
)
schedulers["DPMSolverMultistepKarras"] = (
DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
use_karras_sigmas=True,
)
)
schedulers["DPMSolverMultistepKarras++"] = (
DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
algorithm_type="dpmsolver++",
use_karras_sigmas=True,
)
)
# schedulers["DDPM"] = DDPMScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["DDIM"] = DDIMScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["LCMScheduler"] = LCMScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["DPMSolverMultistep"] = DPMSolverMultistepScheduler.from_pretrained(
# model_id, subfolder="scheduler", algorithm_type="dpmsolver"
# )
# schedulers["DPMSolverMultistep++"] = DPMSolverMultistepScheduler.from_pretrained(
# model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
# )
# schedulers["DPMSolverMultistepKarras"] = (
# DPMSolverMultistepScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# use_karras_sigmas=True,
# )
# )
# schedulers["DPMSolverMultistepKarras++"] = (
# DPMSolverMultistepScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# algorithm_type="dpmsolver++",
# use_karras_sigmas=True,
# )
# )
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
@@ -75,24 +75,24 @@ def get_schedulers(model_id):
subfolder="scheduler",
)
)
schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["DPMSolverSinglestep"] = DPMSolverSinglestepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["KDPM2AncestralDiscrete"] = (
KDPM2AncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
)
schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
# schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["DPMSolverSinglestep"] = DPMSolverSinglestepScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["KDPM2AncestralDiscrete"] = (
# KDPM2AncestralDiscreteScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# )
# schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
return schedulers
@@ -102,17 +102,17 @@ def export_scheduler_model(model):
scheduler_model_map = {
"PNDM": export_scheduler_model("PNDMScheduler"),
"DPMSolverSDE": export_scheduler_model("DpmSolverSDEScheduler"),
# "DPMSolverSDE": export_scheduler_model("DpmSolverSDEScheduler"),
"EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"),
"EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"),
"LCM": export_scheduler_model("LCMScheduler"),
"LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"),
"DDPM": export_scheduler_model("DDPMScheduler"),
"DDIM": export_scheduler_model("DDIMScheduler"),
"DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"),
"KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"),
"DEISMultistep": export_scheduler_model("DEISMultistepScheduler"),
"DPMSolverSinglestep": export_scheduler_model("DPMSolverSingleStepScheduler"),
"KDPM2AncestralDiscrete": export_scheduler_model("KDPM2AncestralDiscreteScheduler"),
"HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"),
# "LCM": export_scheduler_model("LCMScheduler"),
# "LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"),
# "DDPM": export_scheduler_model("DDPMScheduler"),
# "DDIM": export_scheduler_model("DDIMScheduler"),
# "DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"),
# "KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"),
# "DEISMultistep": export_scheduler_model("DEISMultistepScheduler"),
# "DPMSolverSinglestep": export_scheduler_model("DPMSolverSingleStepScheduler"),
# "KDPM2AncestralDiscrete": export_scheduler_model("KDPM2AncestralDiscreteScheduler"),
# "HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"),
}

View File

@@ -117,6 +117,7 @@ def pull_sd_configs(
custom_vae,
precision,
device,
target_triple,
ondemand,
repeatable_seeds,
resample_type,
@@ -175,6 +176,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
sd_json["custom_vae"],
sd_json["precision"],
sd_json["device"],
sd_json["target_triple"],
sd_json["ondemand"],
sd_json["repeatable_seeds"],
sd_json["resample_type"],
@@ -253,6 +255,11 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
choices=global_obj.get_device_list(),
allow_custom_value=False,
)
target_triple = gr.Textbox(
elem_id="triple",
label="Architecture",
value="",
)
with gr.Row():
ondemand = gr.Checkbox(
value=cmd_opts.lowvram,
@@ -691,6 +698,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
custom_vae,
precision,
device,
target_triple,
ondemand,
repeatable_seeds,
resample_type,
@@ -730,6 +738,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
custom_vae,
precision,
device,
target_triple,
ondemand,
repeatable_seeds,
resample_type,