Add igpu and custom triple support. (#2148)

This commit is contained in:
Ean Garvey
2024-05-29 17:39:36 -05:00
committed by GitHub
parent 2074df40ad
commit 13e1d8d98a
3 changed files with 32 additions and 2 deletions

View File

@@ -73,6 +73,7 @@ class StableDiffusion:
scheduler: str,
precision: str,
device: str,
target_triple: str = None,
custom_vae: str = None,
num_loras: int = 0,
import_ir: bool = True,
@@ -91,7 +92,7 @@ class StableDiffusion:
self.model_map = EMPTY_SD_MAP
external_weights = "safetensors"
max_length = 64
target_backend, self.rt_device, triple = parse_device(device)
target_backend, self.rt_device, triple = parse_device(device, target_triple)
pipe_id_list = [
safe_name(base_model_id),
str(batch_size),
@@ -273,6 +274,7 @@ def shark_sd_fn(
custom_vae: str,
precision: str,
device: str,
target_triple: str,
ondemand: bool,
repeatable_seeds: bool,
resample_type: str,
@@ -326,6 +328,7 @@ def shark_sd_fn(
"batch_size": batch_size,
"precision": precision,
"device": device,
"target_triple": target_triple,
"custom_vae": custom_vae,
"num_loras": num_loras,
"import_ir": import_ir,

View File

@@ -77,6 +77,21 @@ def get_available_devices():
available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
available_devices.extend(cpu_device)
print(available_devices)
for idx, device_str in enumerate(available_devices):
if "AMD Radeon(TM) Graphics =>" in device_str:
igpu_id_candidates = [
x.split("w/")[-1].split("=>")[0]
for x in available_devices
if "M Graphics" in x
]
for igpu_name in igpu_id_candidates:
if igpu_name:
print(f"Found iGPU: {igpu_name} for {device_str}")
available_devices[idx] = device_str.replace(
"AMD Radeon(TM) Graphics", f"AMD iGPU: {igpu_name}"
)
break
return available_devices
@@ -129,7 +144,7 @@ def set_iree_runtime_flags():
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
def parse_device(device_str):
def parse_device(device_str, target_override=""):
from shark.iree_utils.compile_utils import (
clean_device_info,
get_iree_target_triple,
@@ -143,6 +158,8 @@ def parse_device(device_str):
else:
rt_device = rt_driver
if target_override:
return target_backend, rt_device, target_override
match target_backend:
case "vulkan-spirv":
triple = get_iree_target_triple(device_str)
@@ -168,6 +185,7 @@ def get_rocm_target_chip(device_str):
"MI100": "gfx908",
"MI50": "gfx906",
"MI60": "gfx906",
"780M": "gfx1103",
}
for key in rocm_chip_map:
if key in device_str:

View File

@@ -117,6 +117,7 @@ def pull_sd_configs(
custom_vae,
precision,
device,
target_triple,
ondemand,
repeatable_seeds,
resample_type,
@@ -175,6 +176,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
sd_json["custom_vae"],
sd_json["precision"],
sd_json["device"],
sd_json["target_triple"],
sd_json["ondemand"],
sd_json["repeatable_seeds"],
sd_json["resample_type"],
@@ -253,6 +255,11 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
choices=global_obj.get_device_list(),
allow_custom_value=False,
)
target_triple = gr.Textbox(
elem_id="triple",
label="Architecture",
value="",
)
with gr.Row():
ondemand = gr.Checkbox(
value=cmd_opts.lowvram,
@@ -691,6 +698,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
custom_vae,
precision,
device,
target_triple,
ondemand,
repeatable_seeds,
resample_type,
@@ -730,6 +738,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
custom_vae,
precision,
device,
target_triple,
ondemand,
repeatable_seeds,
resample_type,