Shark Studio SDXL support, HIP driver support, simpler device info, small fixes

This commit is contained in:
Ean Garvey
2024-04-24 04:01:28 -05:00
parent fd07cae991
commit 0f9930096a
9 changed files with 163 additions and 401 deletions

.gitignore vendored
View File

@@ -188,6 +188,11 @@ variants.json
# models folder
apps/stable_diffusion/web/models/
# model artifacts (SHARK)
*.tempfile
*.mlir
*.vmfb
# Stencil annotators.
stencil_annotator/

View File

@@ -9,14 +9,15 @@ from tqdm.auto import tqdm
from pathlib import Path
from random import randint
from turbine_models.custom_models.sd_inference import clip, unet, vae
from turbine_models.custom_models.sdxl_inference import sdxl_compiled_pipeline
from apps.shark_studio.api.controlnet import control_adapter_map
from apps.shark_studio.api.utils import parse_device
from apps.shark_studio.web.utils.state import status_label
from apps.shark_studio.web.utils.file_utils import (
safe_name,
get_resource_path,
get_checkpoints_path,
)
from apps.shark_studio.modules.pipeline import SharkPipelineBase
from apps.shark_studio.modules.schedulers import get_schedulers
from apps.shark_studio.modules.prompt_encoding import (
get_weighted_text_embeddings,
@@ -32,8 +33,6 @@ from apps.shark_studio.modules.ckpt_processing import (
preprocessCKPT,
process_custom_pipe_weights,
)
from transformers import CLIPTokenizer
from diffusers.image_processor import VaeImageProcessor
sd_model_map = {
"clip": {
@@ -47,8 +46,15 @@ sd_model_map = {
},
}
EMPTY_FLAGS = {
"clip": None,
"unet": None,
"vae": None,
"pipeline": None,
}
class StableDiffusion(SharkPipelineBase):
class StableDiffusion:
# This class is responsible for executing image generation and creating/
# managing a set of compiled modules to run Stable Diffusion. The init
# aims to be as general as possible, and the class will infer and compile
@@ -61,6 +67,8 @@ class StableDiffusion(SharkPipelineBase):
height: int,
width: int,
batch_size: int,
steps: int,
scheduler: str,
precision: str,
device: str,
custom_vae: str = None,
@@ -69,58 +77,18 @@ class StableDiffusion(SharkPipelineBase):
is_controlled: bool = False,
hf_auth_token=None,
):
self.model_max_length = 77
self.batch_size = batch_size
self.precision = precision
self.dtype = torch.float16 if precision == "fp16" else torch.float32
self.height = height
self.width = width
self.scheduler_obj = {}
static_kwargs = {
"pipe": {
"external_weights": "safetensors",
},
"clip": {"hf_model_name": base_model_id},
"unet": {
"hf_model_name": base_model_id,
"unet_model": unet.UnetModel(hf_model_name=base_model_id),
"batch_size": batch_size,
# "is_controlled": is_controlled,
# "num_loras": num_loras,
"height": height,
"width": width,
"precision": precision,
"max_length": self.model_max_length,
},
"vae_encode": {
"hf_model_name": base_model_id,
"vae_model": vae.VaeModel(
hf_model_name=custom_vae if custom_vae else base_model_id,
),
"batch_size": batch_size,
"height": height,
"width": width,
"precision": precision,
},
"vae_decode": {
"hf_model_name": base_model_id,
"vae_model": vae.VaeModel(
hf_model_name=custom_vae if custom_vae else base_model_id,
),
"batch_size": batch_size,
"height": height,
"width": width,
"precision": precision,
},
}
super().__init__(sd_model_map, base_model_id, static_kwargs, device, import_ir)
self.compiled_pipeline = False
self.base_model_id = base_model_id
external_weights = "safetensors"
max_length = 64
target_backend, self.rt_device, triple = parse_device(device)
pipe_id_list = [
safe_name(base_model_id),
str(batch_size),
str(self.model_max_length),
str(max_length),
f"{str(height)}x{str(width)}",
precision,
self.device,
triple,
]
if num_loras > 0:
pipe_id_list.append(str(num_loras) + "lora")
@@ -129,227 +97,67 @@ class StableDiffusion(SharkPipelineBase):
if custom_vae:
pipe_id_list.append(custom_vae)
self.pipe_id = "_".join(pipe_id_list)
self.weights_path = os.path.join(
get_checkpoints_path(), safe_name(self.base_model_id)
)
if not os.path.exists(self.weights_path):
os.mkdir(self.weights_path)
self.sd_pipe = sdxl_compiled_pipeline.SharkSDXLPipeline(
hf_model_name=base_model_id,
scheduler_id=scheduler,
height=height,
width=width,
precision=precision,
max_length=max_length,
batch_size=batch_size,
num_inference_steps=steps,
device=target_backend,
iree_target_triple=triple,
ireec_flags=EMPTY_FLAGS,
attn_spec=None,
decomp_attn=True if "gfx9" not in triple else False,
pipeline_dir=self.pipe_id,
external_weights_dir=self.weights_path,
external_weights=external_weights,
)
print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
del static_kwargs
gc.collect()
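For illustration, a rough sketch of how the reworked constructor might now be invoked from the web layer; every argument value below is hypothetical and simply mirrors the parameters listed above, it is not a set of defaults.
# Hypothetical usage sketch only; values are illustrative (SDXL base model on a ROCm GPU).
sd = StableDiffusion(
    base_model_id="stabilityai/stable-diffusion-xl-base-1.0",
    height=1024,
    width=1024,
    batch_size=1,
    steps=30,
    scheduler="EulerDiscrete",
    precision="fp16",
    device="AMD Radeon RX 7900 XTX => rocm://0",
    custom_vae=None,
    num_loras=0,
    import_ir=True,
    is_controlled=False,
)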
def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img):
print(f"\n[LOG] Preparing pipeline...")
self.is_img2img = is_img2img
self.schedulers = get_schedulers(self.base_model_id)
self.weights_path = os.path.join(
get_checkpoints_path(), self.safe_name(self.base_model_id)
)
if not os.path.exists(self.weights_path):
os.mkdir(self.weights_path)
for model in adapters:
self.model_map[model] = adapters[model]
for submodel in self.static_kwargs:
if custom_weights:
custom_weights_params, _ = process_custom_pipe_weights(custom_weights)
if submodel not in ["clip", "clip2"]:
self.static_kwargs[submodel][
"external_weights"
] = custom_weights_params
else:
self.static_kwargs[submodel]["external_weight_path"] = os.path.join(
self.weights_path, submodel + ".safetensors"
)
else:
self.static_kwargs[submodel]["external_weight_path"] = os.path.join(
self.weights_path, submodel + ".safetensors"
)
self.get_compiled_map(pipe_id=self.pipe_id)
print("\n[LOG] Pipeline successfully prepared for runtime.")
mlirs = {
"prompt_encoder": None,
"scheduled_unet": None,
"vae_decode": None,
"pipeline": None,
"full_pipeline": None,
}
vmfbs = {
"prompt_encoder": None,
"scheduled_unet": None,
"vae_decode": None,
"pipeline": None,
"full_pipeline": None,
}
weights = {
"prompt_encoder": None,
"scheduled_unet": None,
"vae_decode": None,
"pipeline": None,
"full_pipeline": None,
}
vmfbs, weights = self.sd_pipe.check_prepared(mlirs, vmfbs, weights, interactive=False)
print(f"\n[LOG] Loading pipeline to device {self.rt_device}.")
self.sd_pipe.load_pipeline(vmfbs, weights, self.rt_device, self.compiled_pipeline)
print("\n[LOG] Pipeline successfully prepared for runtime. Generating images...")
return
def encode_prompts_weight(
self,
prompt,
negative_prompt,
do_classifier_free_guidance=True,
):
# Encodes the prompt into text encoder hidden states.
self.load_submodels(["clip"])
self.tokenizer = CLIPTokenizer.from_pretrained(
self.base_model_id,
subfolder="tokenizer",
)
clip_inf_start = time.time()
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
)
if do_classifier_free_guidance:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
pad = (0, 0) * (len(text_embeddings.shape) - 2)
pad = pad + (
0,
self.static_kwargs["unet"]["max_length"] - text_embeddings.shape[1],
)
text_embeddings = torch.nn.functional.pad(text_embeddings, pad)
# SHARK: Report clip inference time
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_submodels(["clip"])
gc.collect()
print(f"\n[LOG] Clip Inference time (ms) = {clip_inf_time:.3f}")
return text_embeddings.numpy().astype(np.float16)
def prepare_latents(
self,
generator,
num_inference_steps,
image,
strength,
):
noise = torch.randn(
(
self.batch_size,
4,
self.height // 8,
self.width // 8,
),
generator=generator,
dtype=self.dtype,
).to("cpu")
self.scheduler.set_timesteps(num_inference_steps)
if self.is_img2img:
init_timestep = min(
int(num_inference_steps * strength), num_inference_steps
)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start:]
latents = self.encode_image(image)
latents = self.scheduler.add_noise(latents, noise, timesteps[0].repeat(1))
return latents, [timesteps]
else:
self.scheduler.is_scale_input_called = True
latents = noise * self.scheduler.init_noise_sigma
return latents, self.scheduler.timesteps
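A quick worked example of the img2img timestep slice above (numbers are illustrative):
# Hypothetical values: with 30 steps and strength 0.6, only the last 18
# timesteps are denoised, starting from scheduler.timesteps[12].
num_inference_steps, strength = 30, 0.6
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 18
t_start = max(num_inference_steps - init_timestep, 0)  # 12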
def encode_image(self, input_image):
self.load_submodels(["vae_encode"])
vae_encode_start = time.time()
latents = self.run("vae_encode", input_image)
vae_inf_time = (time.time() - vae_encode_start) * 1000
if self.ondemand:
self.unload_submodels(["vae_encode"])
print(f"\n[LOG] VAE Encode Inference time (ms): {vae_inf_time:.3f}")
return latents
def produce_img_latents(
self,
latents,
text_embeddings,
guidance_scale,
total_timesteps,
cpu_scheduling,
mask=None,
masked_image_latents=None,
return_all_latents=False,
):
# self.status = SD_STATE_IDLE
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(self.dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
guidance_scale = torch.Tensor([guidance_scale]).to(self.dtype)
self.load_submodels(["unet"])
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(self.dtype).detach().numpy()
latent_model_input = self.scheduler.scale_model_input(latents, t).to(
self.dtype
)
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
torch.from_numpy(np.asarray(latent_model_input)).to(self.dtype),
mask,
masked_image_latents,
],
dim=1,
).to(self.dtype)
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
# Profiling Unet.
# profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.run(
"unet",
[
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
],
)
# end_profiling(profile_device)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
else:
latents = self.run("scheduler_step", (noise_pred, t, latents))
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# print(
# f"\n [LOG] step = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
# if self.status == SD_STATE_CANCEL:
# break
if self.ondemand:
self.unload_submodels(["unet"])
gc.collect()
avg_step_time = step_time_sum / len(total_timesteps)
print(f"\n[LOG] Average step time: {avg_step_time}ms/it")
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def decode_latents(self, latents, cpu_scheduling=True):
latents_numpy = latents.to(self.dtype)
if cpu_scheduling:
latents_numpy = latents.detach().numpy()
# profile_device = start_profiling(file_path="vae.rdc")
vae_start = time.time()
images = self.run("vae_decode", latents_numpy).to_host()
vae_inf_time = (time.time() - vae_start) * 1000
# end_profiling(profile_device)
print(f"\n[LOG] VAE Inference time (ms): {vae_inf_time:.3f}")
images = torch.from_numpy(images).permute(0, 2, 3, 1).float().numpy()
pil_images = self.image_processor.numpy_to_pil(images)
return pil_images
def generate_images(
self,
prompt,
negative_prompt,
image,
scheduler,
steps,
strength,
guidance_scale,
seed,
@@ -359,69 +167,15 @@ class StableDiffusion(SharkPipelineBase):
control_mode,
hints,
):
# TODO: Batched args
self.image_processor = VaeImageProcessor(do_convert_rgb=True)
self.scheduler = self.schedulers[scheduler]
self.ondemand = ondemand
if self.is_img2img:
image, _ = self.image_processor.preprocess(image, resample_type)
else:
image = None
print("\n[LOG] Generating images...")
batched_args = [
prompt,
negative_prompt,
image,
]
for arg in batched_args:
if not isinstance(arg, list):
arg = [arg] * self.batch_size
if len(arg) < self.batch_size:
arg = arg * self.batch_size
else:
arg = [arg[i] for i in range(self.batch_size)]
text_embeddings = self.encode_prompts_weight(
img = self.sd_pipe.generate_images(
prompt,
negative_prompt,
1,
guidance_scale,
seed,
return_imgs=True,
)
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
init_latents, final_timesteps = self.prepare_latents(
generator=generator,
num_inference_steps=steps,
image=image,
strength=strength,
)
latents = self.produce_img_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=final_timesteps,
cpu_scheduling=True, # until we have schedulers through Turbine
)
# Img latents -> PIL images
all_imgs = []
self.load_submodels(["vae_decode"])
for i in tqdm(range(0, latents.shape[0], self.batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + self.batch_size],
cpu_scheduling=True,
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_submodels(["vae_decode"])
return all_imgs
return img
def shark_sd_fn_dict_input(
@@ -516,6 +270,8 @@ def shark_sd_fn(
"num_loras": num_loras,
"import_ir": cmd_opts.import_mlir,
"is_controlled": is_controlled,
"steps": steps,
"scheduler": scheduler,
}
submit_prep_kwargs = {
"custom_weights": custom_weights,
@@ -527,8 +283,6 @@ def shark_sd_fn(
"prompt": prompt,
"negative_prompt": negative_prompt,
"image": sd_init_image,
"steps": steps,
"scheduler": scheduler,
"strength": strength,
"guidance_scale": guidance_scale,
"seed": seed,
@@ -566,9 +320,9 @@ def shark_sd_fn(
for current_batch in range(batch_count):
start_time = time.time()
out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
total_time = time.time() - start_time
text_output = f"Total image(s) generation time: {total_time:.4f}sec"
print(f"\n[LOG] {text_output}")
# total_time = time.time() - start_time
# text_output = f"Total image(s) generation time: {total_time:.4f}sec"
#print(f"\n[LOG] {text_output}")
# if global_obj.get_sd_status() == SD_STATE_CANCEL:
# break
# else:
@@ -595,6 +349,9 @@ def view_json_file(file_path):
content = fopen.read()
return content
def safe_name(name):
return name.replace("/", "_").replace("-", "_").replace("\\", "_").replace(".", "_")
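For reference, an illustrative example of what the helper above produces for the SDXL base model id used elsewhere in this commit:
# Illustrative only: slashes, dashes, backslashes, and dots all become underscores.
assert (
    safe_name("stabilityai/stable-diffusion-xl-base-1.0")
    == "stabilityai_stable_diffusion_xl_base_1_0"
)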
if __name__ == "__main__":
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts

View File

@@ -12,11 +12,6 @@ from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from cpuinfo import get_cpu_info
# TODO: migrate these utils to studio
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
get_iree_vulkan_runtime_flags,
)
def get_available_devices():
@@ -49,8 +44,6 @@ def get_available_devices():
device_list.append(f"{device_name} => {driver_name}://{i}")
return device_list
set_iree_runtime_flags()
available_devices = []
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
@@ -71,6 +64,8 @@ def get_available_devices():
available_devices.extend(cuda_devices)
rocm_devices = get_devices_by_name("rocm")
available_devices.extend(rocm_devices)
hip_devices = get_devices_by_name("hip")
available_devices.extend(hip_devices)
cpu_device = get_devices_by_name("cpu-sync")
available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
@@ -78,54 +73,45 @@ def get_available_devices():
return available_devices
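The strings assembled above follow the "{device_name} => {driver_name}://{index}" pattern; on a machine with a single ROCm-capable GPU the result might look roughly like the sketch below (device names are hypothetical):
# Hypothetical output of get_available_devices(); names and indices vary by machine.
[
    "AMD Radeon RX 7900 XTX => rocm://0",
    "AMD Radeon RX 7900 XTX => hip://0",
    "AMD Ryzen 9 7950X => cpu-sync://0",
    "AMD Ryzen 9 7950X => cpu-task://0",
]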
def set_init_device_flags():
if "vulkan" in cmd_opts.device:
# set runtime flags for vulkan.
set_iree_runtime_flags()
def parse_device(device_str):
from shark.iree_utils.compile_utils import clean_device_info, get_iree_target_triple, iree_target_map
rt_driver, device_id = clean_device_info(device_str)
target_backend = iree_target_map(rt_driver)
if device_id:
rt_device = f"{rt_driver}://{device_id}"
else:
rt_device = rt_driver
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
if not cmd_opts.iree_vulkan_target_triple:
triple = get_vulkan_target_triple(device_name)
if triple is not None:
cmd_opts.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple "
f"{cmd_opts.iree_vulkan_target_triple}."
)
elif "cuda" in cmd_opts.device:
cmd_opts.device = "cuda"
elif "metal" in cmd_opts.device:
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
if not cmd_opts.iree_metal_target_platform:
from shark.iree_utils.metal_utils import get_metal_target_triple
triple = get_metal_target_triple(device_name)
if triple is not None:
cmd_opts.iree_metal_target_platform = triple.split("-")[-1]
print(
f"Found device {device_name}. Using target triple "
f"{cmd_opts.iree_metal_target_platform}."
)
elif "cpu" in cmd_opts.device:
cmd_opts.device = "cpu"
match target_backend:
case "vulkan-spirv":
triple = get_iree_target_triple(device_str)
return target_backend, rt_device, triple
case "rocm":
triple = get_rocm_target_chip(device_str)
return target_backend, rt_device, triple
case "cpu":
return "llvm-cpu", "local-task", "x86_64-linux-gnu"
def set_iree_runtime_flags():
# TODO: This function should be device-agnostic and piped properly
# to general runtime driver init.
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
if cmd_opts.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
if cmd_opts.device_allocator_heap_key:
vulkan_runtime_flags += [
f"--device_allocator=caching:device_local={cmd_opts.device_allocator_heap_key}",
]
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
def get_rocm_target_chip(device_str):
# TODO: Use a data file to map device_str to target chip.
rocm_chip_map = {
"6700": "gfx1031",
"6800": "gfx1030",
"6900": "gfx1030",
"7900": "gfx1100",
"MI300X": "gfx942",
"MI300A": "gfx940",
"MI210": "gfx90a",
"MI250": "gfx90a",
"MI100": "gfx908",
"MI50": "gfx906",
"MI60": "gfx906",
}
for key in rocm_chip_map:
if key in device_str:
return rocm_chip_map[key]
raise AssertionError(f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/SHARK/issues.")
def get_all_devices(driver_name):
"""

View File

@@ -45,11 +45,10 @@ from apps.shark_studio.modules import logger
import apps.shark_studio.web.utils.globals as global_obj
sd_default_models = [
"CompVis/stable-diffusion-v1-4",
"runwayml/stable-diffusion-v1-5",
"stabilityai/stable-diffusion-2-1-base",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-xl-1.0",
"stabilityai/stable-diffusion-xl-base-1.0",
"stabilityai/sdxl-turbo",
]
@@ -286,14 +285,14 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
with gr.Row():
height = gr.Slider(
384,
768,
1024,
value=cmd_opts.height,
step=8,
label="\U00002195\U0000FE0F Height",
)
width = gr.Slider(
384,
768,
1024,
value=cmd_opts.width,
step=8,
label="\U00002194\U0000FE0F Width",

View File

@@ -5,9 +5,10 @@
setuptools
wheel
torch==2.3.0
shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@main#subdirectory=models
torch>=2.3.0
shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-sdxl-fixes#subdirectory=core
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-sdxl-fixes#subdirectory=models
diffusers @ git+https://github.com/nod-ai/diffusers@v0.24.0-release
# SHARK Runner
tqdm

View File

@@ -88,5 +88,7 @@ else {python -m venv .\shark.venv\}
.\shark.venv\Scripts\activate
python -m pip install --upgrade pip
pip install wheel
pip install -r requirements.txt
pip install --pre -r requirements.txt
Write-Host "Source your venv with ./shark.venv/Scripts/activate"

View File

@@ -76,6 +76,7 @@ _IREE_DEVICE_MAP = {
"vulkan": "vulkan",
"metal": "metal",
"rocm": "rocm",
"hip": "hip",
"intel-gpu": "level_zero",
}
@@ -94,6 +95,7 @@ _IREE_TARGET_MAP = {
"vulkan": "vulkan-spirv",
"metal": "metal",
"rocm": "rocm",
"hip": "rocm",
"intel-gpu": "opencl-spirv",
}
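With the two new entries, "hip" is exposed as its own runtime driver while still compiling for the ROCm backend; a hedged sketch assuming this module's existing map-lookup helpers:
# Sketch assuming the existing iree_device_map/iree_target_map helpers in this module.
iree_device_map("hip")  # -> "hip"   (runtime driver)
iree_target_map("hip")  # -> "rocm"  (compile target backend)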

View File

@@ -62,6 +62,9 @@ def get_iree_device_args(device, extra_args=[]):
from shark.iree_utils.gpu_utils import get_iree_rocm_args
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
if device == "hip":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args, hip_driver=True)
return []
def get_iree_target_triple(device):

View File

@@ -52,7 +52,7 @@ def check_rocm_device_arch_in_args(extra_args):
return None
def get_rocm_device_arch(device_num=0, extra_args=[]):
def get_rocm_device_arch(device_num=0, extra_args=[], hip_driver=False):
# ROCM Device Arch selection:
# 1 : User given device arch using `--iree-rocm-target-chip` flag
# 2 : Device arch from `iree-run-module --dump_devices=rocm` for device on index <device_num>
@@ -68,15 +68,23 @@ def get_rocm_device_arch(device_num=0, extra_args=[]):
arch_in_device_dump = None
# get rocm arch from iree dump devices
def get_devices_info_from_dump(dump):
def get_devices_info_from_dump(dump, driver):
from os import linesep
dump_clean = list(
filter(
lambda s: "--device=rocm" in s or "gpu-arch-name:" in s,
dump.split(linesep),
if driver == "hip":
dump_clean = list(
filter(
lambda s: "AMD" in s,
dump.split(linesep),
)
)
else:
dump_clean = list(
filter(
lambda s: f"--device={driver}" in s or "gpu-arch-name:" in s,
dump.split(linesep),
)
)
)
arch_pairs = [
(
dump_clean[i].split("=")[1].strip(),
@@ -87,16 +95,17 @@ def get_rocm_device_arch(device_num=0, extra_args=[]):
return arch_pairs
dump_device_info = None
driver = "hip" if hip_driver else "rocm"
try:
dump_device_info = run_cmd(
"iree-run-module --dump_devices=rocm", raise_err=True
"iree-run-module --dump_devices=" + driver, raise_err=True
)
except Exception as e:
print("could not execute `iree-run-module --dump_devices=rocm`")
print("could not execute `iree-run-module --dump_devices=" + driver + "`")
if dump_device_info is not None:
device_num = 0 if device_num is None else device_num
device_arch_pairs = get_devices_info_from_dump(dump_device_info[0])
device_arch_pairs = get_devices_info_from_dump(dump_device_info[0], driver)
if len(device_arch_pairs) > device_num: # can find arch in the list
arch_in_device_dump = device_arch_pairs[device_num][1]
@@ -107,24 +116,22 @@ def get_rocm_device_arch(device_num=0, extra_args=[]):
default_rocm_arch = "gfx1100"
print(
"Did not find ROCm architecture from `--iree-rocm-target-chip` flag"
"\n or from `iree-run-module --dump_devices=rocm` command."
"\n or from `iree-run-module --dump_devices` command."
f"\nUsing {default_rocm_arch} as ROCm arch for compilation."
)
return default_rocm_arch
# Get the default gpu args given the architecture.
def get_iree_rocm_args(device_num=0, extra_args=[]):
def get_iree_rocm_args(device_num=0, extra_args=[], hip_driver=False):
ireert.flags.FUNCTION_INPUT_VALIDATION = False
rocm_flags = ["--iree-rocm-link-bc=true"]
rocm_flags = []
if check_rocm_device_arch_in_args(extra_args) is None:
rocm_arch = get_rocm_device_arch(device_num, extra_args)
rocm_arch = get_rocm_device_arch(device_num, extra_args, hip_driver=hip_driver)
rocm_flags.append(f"--iree-rocm-target-chip={rocm_arch}")
return rocm_flags
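A hedged example of the updated flag builder (the architecture depends on what the dump_devices probe finds):
# Hypothetical result; arch detection now runs `iree-run-module --dump_devices=hip`
# when hip_driver=True, and the old --iree-rocm-link-bc flag is no longer emitted.
get_iree_rocm_args(device_num=0, hip_driver=True)
# -> ["--iree-rocm-target-chip=gfx1100"]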
# Some constants taken from cuda.h
CUDA_SUCCESS = 0
CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT = 16