mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Shark Studio SDXL support, HIP driver support, simpler device info, small fixes
This commit is contained in:
5
.gitignore
vendored
5
.gitignore
vendored
@@ -188,6 +188,11 @@ variants.json
|
||||
# models folder
|
||||
apps/stable_diffusion/web/models/
|
||||
|
||||
# model artifacts (SHARK)
|
||||
*.tempfile
|
||||
*.mlir
|
||||
*.vmfb
|
||||
|
||||
# Stencil annotators.
|
||||
stencil_annotator/
|
||||
|
||||
|
||||
@@ -9,14 +9,15 @@ from tqdm.auto import tqdm
|
||||
from pathlib import Path
|
||||
from random import randint
|
||||
from turbine_models.custom_models.sd_inference import clip, unet, vae
|
||||
from turbine_models.custom_models.sdxl_inference import sdxl_compiled_pipeline
|
||||
from apps.shark_studio.api.controlnet import control_adapter_map
|
||||
from apps.shark_studio.api.utils import parse_device
|
||||
from apps.shark_studio.web.utils.state import status_label
|
||||
from apps.shark_studio.web.utils.file_utils import (
|
||||
safe_name,
|
||||
get_resource_path,
|
||||
get_checkpoints_path,
|
||||
)
|
||||
from apps.shark_studio.modules.pipeline import SharkPipelineBase
|
||||
from apps.shark_studio.modules.schedulers import get_schedulers
|
||||
from apps.shark_studio.modules.prompt_encoding import (
|
||||
get_weighted_text_embeddings,
|
||||
@@ -32,8 +33,6 @@ from apps.shark_studio.modules.ckpt_processing import (
|
||||
preprocessCKPT,
|
||||
process_custom_pipe_weights,
|
||||
)
|
||||
from transformers import CLIPTokenizer
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
|
||||
sd_model_map = {
|
||||
"clip": {
|
||||
@@ -47,8 +46,15 @@ sd_model_map = {
|
||||
},
|
||||
}
|
||||
|
||||
EMPTY_FLAGS = {
|
||||
"clip": None,
|
||||
"unet": None,
|
||||
"vae": None,
|
||||
"pipeline": None,
|
||||
}
|
||||
|
||||
class StableDiffusion(SharkPipelineBase):
|
||||
|
||||
class StableDiffusion:
|
||||
# This class is responsible for executing image generation and creating
|
||||
# /managing a set of compiled modules to run Stable Diffusion. The init
|
||||
# aims to be as general as possible, and the class will infer and compile
|
||||
@@ -61,6 +67,8 @@ class StableDiffusion(SharkPipelineBase):
|
||||
height: int,
|
||||
width: int,
|
||||
batch_size: int,
|
||||
steps: int,
|
||||
scheduler: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
custom_vae: str = None,
|
||||
@@ -69,58 +77,18 @@ class StableDiffusion(SharkPipelineBase):
|
||||
is_controlled: bool = False,
|
||||
hf_auth_token=None,
|
||||
):
|
||||
self.model_max_length = 77
|
||||
self.batch_size = batch_size
|
||||
self.precision = precision
|
||||
self.dtype = torch.float16 if precision == "fp16" else torch.float32
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.scheduler_obj = {}
|
||||
static_kwargs = {
|
||||
"pipe": {
|
||||
"external_weights": "safetensors",
|
||||
},
|
||||
"clip": {"hf_model_name": base_model_id},
|
||||
"unet": {
|
||||
"hf_model_name": base_model_id,
|
||||
"unet_model": unet.UnetModel(hf_model_name=base_model_id),
|
||||
"batch_size": batch_size,
|
||||
# "is_controlled": is_controlled,
|
||||
# "num_loras": num_loras,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"precision": precision,
|
||||
"max_length": self.model_max_length,
|
||||
},
|
||||
"vae_encode": {
|
||||
"hf_model_name": base_model_id,
|
||||
"vae_model": vae.VaeModel(
|
||||
hf_model_name=custom_vae if custom_vae else base_model_id,
|
||||
),
|
||||
"batch_size": batch_size,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"precision": precision,
|
||||
},
|
||||
"vae_decode": {
|
||||
"hf_model_name": base_model_id,
|
||||
"vae_model": vae.VaeModel(
|
||||
hf_model_name=custom_vae if custom_vae else base_model_id,
|
||||
),
|
||||
"batch_size": batch_size,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"precision": precision,
|
||||
},
|
||||
}
|
||||
super().__init__(sd_model_map, base_model_id, static_kwargs, device, import_ir)
|
||||
self.compiled_pipeline = False
|
||||
self.base_model_id = base_model_id
|
||||
external_weights = "safetensors"
|
||||
max_length = 64
|
||||
target_backend, self.rt_device, triple = parse_device(device)
|
||||
pipe_id_list = [
|
||||
safe_name(base_model_id),
|
||||
str(batch_size),
|
||||
str(self.model_max_length),
|
||||
str(max_length),
|
||||
f"{str(height)}x{str(width)}",
|
||||
precision,
|
||||
self.device,
|
||||
triple,
|
||||
]
|
||||
if num_loras > 0:
|
||||
pipe_id_list.append(str(num_loras) + "lora")
|
||||
@@ -129,227 +97,67 @@ class StableDiffusion(SharkPipelineBase):
|
||||
if custom_vae:
|
||||
pipe_id_list.append(custom_vae)
|
||||
self.pipe_id = "_".join(pipe_id_list)
|
||||
self.weights_path = os.path.join(
|
||||
get_checkpoints_path(), safe_name(self.base_model_id)
|
||||
)
|
||||
if not os.path.exists(self.weights_path):
|
||||
os.mkdir(self.weights_path)
|
||||
self.sd_pipe = sdxl_compiled_pipeline.SharkSDXLPipeline(
|
||||
hf_model_name=base_model_id,
|
||||
scheduler_id=scheduler,
|
||||
height=height,
|
||||
width=width,
|
||||
precision=precision,
|
||||
max_length=max_length,
|
||||
batch_size=batch_size,
|
||||
num_inference_steps=steps,
|
||||
device=target_backend,
|
||||
iree_target_triple=triple,
|
||||
ireec_flags=EMPTY_FLAGS,
|
||||
attn_spec=None,
|
||||
decomp_attn=True if "gfx9" not in triple else False,
|
||||
pipeline_dir=self.pipe_id,
|
||||
external_weights_dir=self.weights_path,
|
||||
external_weights=external_weights,
|
||||
)
|
||||
print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
|
||||
del static_kwargs
|
||||
gc.collect()
|
||||
|
||||
def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img):
|
||||
print(f"\n[LOG] Preparing pipeline...")
|
||||
self.is_img2img = is_img2img
|
||||
self.schedulers = get_schedulers(self.base_model_id)
|
||||
|
||||
self.weights_path = os.path.join(
|
||||
get_checkpoints_path(), self.safe_name(self.base_model_id)
|
||||
)
|
||||
if not os.path.exists(self.weights_path):
|
||||
os.mkdir(self.weights_path)
|
||||
|
||||
for model in adapters:
|
||||
self.model_map[model] = adapters[model]
|
||||
|
||||
for submodel in self.static_kwargs:
|
||||
if custom_weights:
|
||||
custom_weights_params, _ = process_custom_pipe_weights(custom_weights)
|
||||
if submodel not in ["clip", "clip2"]:
|
||||
self.static_kwargs[submodel][
|
||||
"external_weights"
|
||||
] = custom_weights_params
|
||||
else:
|
||||
self.static_kwargs[submodel]["external_weight_path"] = os.path.join(
|
||||
self.weights_path, submodel + ".safetensors"
|
||||
)
|
||||
else:
|
||||
self.static_kwargs[submodel]["external_weight_path"] = os.path.join(
|
||||
self.weights_path, submodel + ".safetensors"
|
||||
)
|
||||
|
||||
self.get_compiled_map(pipe_id=self.pipe_id)
|
||||
print("\n[LOG] Pipeline successfully prepared for runtime.")
|
||||
mlirs = {
|
||||
"prompt_encoder": None,
|
||||
"scheduled_unet": None,
|
||||
"vae_decode": None,
|
||||
"pipeline": None,
|
||||
"full_pipeline": None,
|
||||
}
|
||||
vmfbs = {
|
||||
"prompt_encoder": None,
|
||||
"scheduled_unet": None,
|
||||
"vae_decode": None,
|
||||
"pipeline": None,
|
||||
"full_pipeline": None,
|
||||
}
|
||||
weights = {
|
||||
"prompt_encoder": None,
|
||||
"scheduled_unet": None,
|
||||
"vae_decode": None,
|
||||
"pipeline": None,
|
||||
"full_pipeline": None,
|
||||
}
|
||||
vmfbs, weights = self.sd_pipe.check_prepared(mlirs, vmfbs, weights, interactive=False)
|
||||
print(f"\n[LOG] Loading pipeline to device {self.rt_device}.")
|
||||
self.sd_pipe.load_pipeline(vmfbs, weights, self.rt_device, self.compiled_pipeline)
|
||||
print("\n[LOG] Pipeline successfully prepared for runtime. Generating images...")
|
||||
return
|
||||
|
||||
def encode_prompts_weight(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
do_classifier_free_guidance=True,
|
||||
):
|
||||
# Encodes the prompt into text encoder hidden states.
|
||||
self.load_submodels(["clip"])
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(
|
||||
self.base_model_id,
|
||||
subfolder="tokenizer",
|
||||
)
|
||||
clip_inf_start = time.time()
|
||||
|
||||
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
|
||||
pipe=self,
|
||||
prompt=prompt,
|
||||
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
pad = (0, 0) * (len(text_embeddings.shape) - 2)
|
||||
pad = pad + (
|
||||
0,
|
||||
self.static_kwargs["unet"]["max_length"] - text_embeddings.shape[1],
|
||||
)
|
||||
text_embeddings = torch.nn.functional.pad(text_embeddings, pad)
|
||||
|
||||
# SHARK: Report clip inference time
|
||||
clip_inf_time = (time.time() - clip_inf_start) * 1000
|
||||
if self.ondemand:
|
||||
self.unload_submodels(["clip"])
|
||||
gc.collect()
|
||||
print(f"\n[LOG] Clip Inference time (ms) = {clip_inf_time:.3f}")
|
||||
|
||||
return text_embeddings.numpy().astype(np.float16)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
image,
|
||||
strength,
|
||||
):
|
||||
noise = torch.randn(
|
||||
(
|
||||
self.batch_size,
|
||||
4,
|
||||
self.height // 8,
|
||||
self.width // 8,
|
||||
),
|
||||
generator=generator,
|
||||
dtype=self.dtype,
|
||||
).to("cpu")
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
if self.is_img2img:
|
||||
init_timestep = min(
|
||||
int(num_inference_steps * strength), num_inference_steps
|
||||
)
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
latents = self.encode_image(image)
|
||||
latents = self.scheduler.add_noise(latents, noise, timesteps[0].repeat(1))
|
||||
return latents, [timesteps]
|
||||
else:
|
||||
self.scheduler.is_scale_input_called = True
|
||||
latents = noise * self.scheduler.init_noise_sigma
|
||||
return latents, self.scheduler.timesteps
|
||||
|
||||
def encode_image(self, input_image):
|
||||
self.load_submodels(["vae_encode"])
|
||||
vae_encode_start = time.time()
|
||||
latents = self.run("vae_encode", input_image)
|
||||
vae_inf_time = (time.time() - vae_encode_start) * 1000
|
||||
if self.ondemand:
|
||||
self.unload_submodels(["vae_encode"])
|
||||
print(f"\n[LOG] VAE Encode Inference time (ms): {vae_inf_time:.3f}")
|
||||
|
||||
return latents
|
||||
|
||||
def produce_img_latents(
|
||||
self,
|
||||
latents,
|
||||
text_embeddings,
|
||||
guidance_scale,
|
||||
total_timesteps,
|
||||
cpu_scheduling,
|
||||
mask=None,
|
||||
masked_image_latents=None,
|
||||
return_all_latents=False,
|
||||
):
|
||||
# self.status = SD_STATE_IDLE
|
||||
step_time_sum = 0
|
||||
latent_history = [latents]
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(self.dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
guidance_scale = torch.Tensor([guidance_scale]).to(self.dtype)
|
||||
self.load_submodels(["unet"])
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
timestep = torch.tensor([t]).to(self.dtype).detach().numpy()
|
||||
latent_model_input = self.scheduler.scale_model_input(latents, t).to(
|
||||
self.dtype
|
||||
)
|
||||
if mask is not None and masked_image_latents is not None:
|
||||
latent_model_input = torch.cat(
|
||||
[
|
||||
torch.from_numpy(np.asarray(latent_model_input)).to(self.dtype),
|
||||
mask,
|
||||
masked_image_latents,
|
||||
],
|
||||
dim=1,
|
||||
).to(self.dtype)
|
||||
if cpu_scheduling:
|
||||
latent_model_input = latent_model_input.detach().numpy()
|
||||
|
||||
# Profiling Unet.
|
||||
# profile_device = start_profiling(file_path="unet.rdc")
|
||||
noise_pred = self.run(
|
||||
"unet",
|
||||
[
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
],
|
||||
)
|
||||
# end_profiling(profile_device)
|
||||
|
||||
if cpu_scheduling:
|
||||
noise_pred = torch.from_numpy(noise_pred.to_host())
|
||||
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
|
||||
else:
|
||||
latents = self.run("scheduler_step", (noise_pred, t, latents))
|
||||
|
||||
latent_history.append(latents)
|
||||
step_time = (time.time() - step_start_time) * 1000
|
||||
# print(
|
||||
# f"\n [LOG] step = {i} | timestep = {t} | time = {step_time:.2f}ms"
|
||||
# )
|
||||
step_time_sum += step_time
|
||||
|
||||
# if self.status == SD_STATE_CANCEL:
|
||||
# break
|
||||
|
||||
if self.ondemand:
|
||||
self.unload_submodels(["unet"])
|
||||
gc.collect()
|
||||
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
print(f"\n[LOG] Average step time: {avg_step_time}ms/it")
|
||||
|
||||
if not return_all_latents:
|
||||
return latents
|
||||
all_latents = torch.cat(latent_history, dim=0)
|
||||
return all_latents
|
||||
|
||||
def decode_latents(self, latents, cpu_scheduling=True):
|
||||
latents_numpy = latents.to(self.dtype)
|
||||
if cpu_scheduling:
|
||||
latents_numpy = latents.detach().numpy()
|
||||
|
||||
# profile_device = start_profiling(file_path="vae.rdc")
|
||||
vae_start = time.time()
|
||||
images = self.run("vae_decode", latents_numpy).to_host()
|
||||
vae_inf_time = (time.time() - vae_start) * 1000
|
||||
# end_profiling(profile_device)
|
||||
print(f"\n[LOG] VAE Inference time (ms): {vae_inf_time:.3f}")
|
||||
|
||||
images = torch.from_numpy(images).permute(0, 2, 3, 1).float().numpy()
|
||||
pil_images = self.image_processor.numpy_to_pil(images)
|
||||
return pil_images
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image,
|
||||
scheduler,
|
||||
steps,
|
||||
strength,
|
||||
guidance_scale,
|
||||
seed,
|
||||
@@ -359,69 +167,15 @@ class StableDiffusion(SharkPipelineBase):
|
||||
control_mode,
|
||||
hints,
|
||||
):
|
||||
# TODO: Batched args
|
||||
self.image_processor = VaeImageProcessor(do_convert_rgb=True)
|
||||
self.scheduler = self.schedulers[scheduler]
|
||||
self.ondemand = ondemand
|
||||
if self.is_img2img:
|
||||
image, _ = self.image_processor.preprocess(image, resample_type)
|
||||
else:
|
||||
image = None
|
||||
|
||||
print("\n[LOG] Generating images...")
|
||||
batched_args = [
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image,
|
||||
]
|
||||
for arg in batched_args:
|
||||
if not isinstance(arg, list):
|
||||
arg = [arg] * self.batch_size
|
||||
if len(arg) < self.batch_size:
|
||||
arg = arg * self.batch_size
|
||||
else:
|
||||
arg = [arg[i] for i in range(self.batch_size)]
|
||||
|
||||
text_embeddings = self.encode_prompts_weight(
|
||||
img = self.sd_pipe.generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
1,
|
||||
guidance_scale,
|
||||
seed,
|
||||
return_imgs=True,
|
||||
)
|
||||
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
init_latents, final_timesteps = self.prepare_latents(
|
||||
generator=generator,
|
||||
num_inference_steps=steps,
|
||||
image=image,
|
||||
strength=strength,
|
||||
)
|
||||
|
||||
latents = self.produce_img_latents(
|
||||
latents=init_latents,
|
||||
text_embeddings=text_embeddings,
|
||||
guidance_scale=guidance_scale,
|
||||
total_timesteps=final_timesteps,
|
||||
cpu_scheduling=True, # until we have schedulers through Turbine
|
||||
)
|
||||
|
||||
# Img latents -> PIL images
|
||||
all_imgs = []
|
||||
self.load_submodels(["vae_decode"])
|
||||
for i in tqdm(range(0, latents.shape[0], self.batch_size)):
|
||||
imgs = self.decode_latents(
|
||||
latents=latents[i : i + self.batch_size],
|
||||
cpu_scheduling=True,
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
if self.ondemand:
|
||||
self.unload_submodels(["vae_decode"])
|
||||
|
||||
return all_imgs
|
||||
return img
|
||||
|
||||
|
||||
def shark_sd_fn_dict_input(
|
||||
@@ -516,6 +270,8 @@ def shark_sd_fn(
|
||||
"num_loras": num_loras,
|
||||
"import_ir": cmd_opts.import_mlir,
|
||||
"is_controlled": is_controlled,
|
||||
"steps": steps,
|
||||
"scheduler": scheduler,
|
||||
}
|
||||
submit_prep_kwargs = {
|
||||
"custom_weights": custom_weights,
|
||||
@@ -527,8 +283,6 @@ def shark_sd_fn(
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"image": sd_init_image,
|
||||
"steps": steps,
|
||||
"scheduler": scheduler,
|
||||
"strength": strength,
|
||||
"guidance_scale": guidance_scale,
|
||||
"seed": seed,
|
||||
@@ -566,9 +320,9 @@ def shark_sd_fn(
|
||||
for current_batch in range(batch_count):
|
||||
start_time = time.time()
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"Total image(s) generation time: {total_time:.4f}sec"
|
||||
print(f"\n[LOG] {text_output}")
|
||||
# total_time = time.time() - start_time
|
||||
# text_output = f"Total image(s) generation time: {total_time:.4f}sec"
|
||||
#print(f"\n[LOG] {text_output}")
|
||||
# if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
# break
|
||||
# else:
|
||||
@@ -595,6 +349,9 @@ def view_json_file(file_path):
|
||||
content = fopen.read()
|
||||
return content
|
||||
|
||||
def safe_name(name):
|
||||
return name.replace("/", "_").replace("-", "_").replace("\\", "_").replace(".", "_")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
|
||||
@@ -12,11 +12,6 @@ from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from cpuinfo import get_cpu_info
|
||||
|
||||
# TODO: migrate these utils to studio
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
get_iree_vulkan_runtime_flags,
|
||||
)
|
||||
|
||||
|
||||
def get_available_devices():
|
||||
@@ -49,8 +44,6 @@ def get_available_devices():
|
||||
device_list.append(f"{device_name} => {driver_name}://{i}")
|
||||
return device_list
|
||||
|
||||
set_iree_runtime_flags()
|
||||
|
||||
available_devices = []
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
@@ -71,6 +64,8 @@ def get_available_devices():
|
||||
available_devices.extend(cuda_devices)
|
||||
rocm_devices = get_devices_by_name("rocm")
|
||||
available_devices.extend(rocm_devices)
|
||||
hip_devices = get_devices_by_name("hip")
|
||||
available_devices.extend(hip_devices)
|
||||
cpu_device = get_devices_by_name("cpu-sync")
|
||||
available_devices.extend(cpu_device)
|
||||
cpu_device = get_devices_by_name("cpu-task")
|
||||
@@ -78,54 +73,45 @@ def get_available_devices():
|
||||
return available_devices
|
||||
|
||||
|
||||
def set_init_device_flags():
|
||||
if "vulkan" in cmd_opts.device:
|
||||
# set runtime flags for vulkan.
|
||||
set_iree_runtime_flags()
|
||||
def parse_device(device_str):
|
||||
from shark.iree_utils.compile_utils import clean_device_info, get_iree_target_triple, iree_target_map
|
||||
rt_driver, device_id = clean_device_info(device_str)
|
||||
target_backend = iree_target_map(rt_driver)
|
||||
if device_id:
|
||||
rt_device = f"{rt_driver}://{device_id}"
|
||||
else:
|
||||
rt_device = rt_driver
|
||||
|
||||
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
|
||||
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
|
||||
if not cmd_opts.iree_vulkan_target_triple:
|
||||
triple = get_vulkan_target_triple(device_name)
|
||||
if triple is not None:
|
||||
cmd_opts.iree_vulkan_target_triple = triple
|
||||
print(
|
||||
f"Found device {device_name}. Using target triple "
|
||||
f"{cmd_opts.iree_vulkan_target_triple}."
|
||||
)
|
||||
elif "cuda" in cmd_opts.device:
|
||||
cmd_opts.device = "cuda"
|
||||
elif "metal" in cmd_opts.device:
|
||||
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
|
||||
if not cmd_opts.iree_metal_target_platform:
|
||||
from shark.iree_utils.metal_utils import get_metal_target_triple
|
||||
|
||||
triple = get_metal_target_triple(device_name)
|
||||
if triple is not None:
|
||||
cmd_opts.iree_metal_target_platform = triple.split("-")[-1]
|
||||
print(
|
||||
f"Found device {device_name}. Using target triple "
|
||||
f"{cmd_opts.iree_metal_target_platform}."
|
||||
)
|
||||
elif "cpu" in cmd_opts.device:
|
||||
cmd_opts.device = "cpu"
|
||||
match target_backend:
|
||||
case "vulkan-spirv":
|
||||
triple = get_iree_target_triple(device_str)
|
||||
return target_backend, rt_device, triple
|
||||
case "rocm":
|
||||
triple = get_rocm_target_chip(device_str)
|
||||
return target_backend, rt_device, triple
|
||||
case "cpu":
|
||||
return "llvm-cpu", "local-task", "x86_64-linux-gnu"
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
# TODO: This function should be device-agnostic and piped properly
|
||||
# to general runtime driver init.
|
||||
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
|
||||
if cmd_opts.enable_rgp:
|
||||
vulkan_runtime_flags += [
|
||||
f"--enable_rgp=true",
|
||||
f"--vulkan_debug_utils=true",
|
||||
]
|
||||
if cmd_opts.device_allocator_heap_key:
|
||||
vulkan_runtime_flags += [
|
||||
f"--device_allocator=caching:device_local={cmd_opts.device_allocator_heap_key}",
|
||||
]
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
def get_rocm_target_chip(device_str):
|
||||
#TODO: Use a data file to map device_str to target chip.
|
||||
rocm_chip_map = {
|
||||
"6700": "gfx1031",
|
||||
"6800": "gfx1030",
|
||||
"6900": "gfx1030",
|
||||
"7900": "gfx1100",
|
||||
"MI300X": "gfx942",
|
||||
"MI300A": "gfx940",
|
||||
"MI210": "gfx90a",
|
||||
"MI250": "gfx90a",
|
||||
"MI100": "gfx908",
|
||||
"MI50": "gfx906",
|
||||
"MI60": "gfx906",
|
||||
}
|
||||
for key in rocm_chip_map:
|
||||
if key in device_str:
|
||||
return rocm_chip_map[key]
|
||||
raise AssertionError(f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/SHARK/issues.")
|
||||
|
||||
def get_all_devices(driver_name):
|
||||
"""
|
||||
|
||||
@@ -45,11 +45,10 @@ from apps.shark_studio.modules import logger
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
|
||||
sd_default_models = [
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-xl-1.0",
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"stabilityai/sdxl-turbo",
|
||||
]
|
||||
|
||||
@@ -286,14 +285,14 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384,
|
||||
768,
|
||||
1024,
|
||||
value=cmd_opts.height,
|
||||
step=8,
|
||||
label="\U00002195\U0000FE0F Height",
|
||||
)
|
||||
width = gr.Slider(
|
||||
384,
|
||||
768,
|
||||
1024,
|
||||
value=cmd_opts.width,
|
||||
step=8,
|
||||
label="\U00002194\U0000FE0F Width",
|
||||
|
||||
@@ -5,9 +5,10 @@
|
||||
setuptools
|
||||
wheel
|
||||
|
||||
torch==2.3.0
|
||||
shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
|
||||
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@main#subdirectory=models
|
||||
torch>=2.3.0
|
||||
shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-sdxl-fixes#subdirectory=core
|
||||
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-sdxl-fixes#subdirectory=models
|
||||
diffusers @ git+https://github.com/nod-ai/diffusers@v0.24.0-release
|
||||
|
||||
# SHARK Runner
|
||||
tqdm
|
||||
|
||||
@@ -88,5 +88,7 @@ else {python -m venv .\shark.venv\}
|
||||
.\shark.venv\Scripts\activate
|
||||
python -m pip install --upgrade pip
|
||||
pip install wheel
|
||||
pip install -r requirements.txt
|
||||
pip install --pre -r requirements.txt
|
||||
|
||||
>>>>>>> 0c904eb7 (Shark Studio SDXL support, HIP driver support, simpler device info, small fixes)
|
||||
Write-Host "Source your venv with ./shark.venv/Scripts/activate"
|
||||
|
||||
@@ -76,6 +76,7 @@ _IREE_DEVICE_MAP = {
|
||||
"vulkan": "vulkan",
|
||||
"metal": "metal",
|
||||
"rocm": "rocm",
|
||||
"hip": "hip",
|
||||
"intel-gpu": "level_zero",
|
||||
}
|
||||
|
||||
@@ -94,6 +95,7 @@ _IREE_TARGET_MAP = {
|
||||
"vulkan": "vulkan-spirv",
|
||||
"metal": "metal",
|
||||
"rocm": "rocm",
|
||||
"hip": "rocm",
|
||||
"intel-gpu": "opencl-spirv",
|
||||
}
|
||||
|
||||
|
||||
@@ -62,6 +62,9 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
|
||||
if device == "hip":
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args, hip_driver=True)
|
||||
return []
|
||||
|
||||
def get_iree_target_triple(device):
|
||||
|
||||
@@ -52,7 +52,7 @@ def check_rocm_device_arch_in_args(extra_args):
|
||||
return None
|
||||
|
||||
|
||||
def get_rocm_device_arch(device_num=0, extra_args=[]):
|
||||
def get_rocm_device_arch(device_num=0, extra_args=[], hip_driver=False):
|
||||
# ROCM Device Arch selection:
|
||||
# 1 : User given device arch using `--iree-rocm-target-chip` flag
|
||||
# 2 : Device arch from `iree-run-module --dump_devices=rocm` for device on index <device_num>
|
||||
@@ -68,15 +68,23 @@ def get_rocm_device_arch(device_num=0, extra_args=[]):
|
||||
arch_in_device_dump = None
|
||||
|
||||
# get rocm arch from iree dump devices
|
||||
def get_devices_info_from_dump(dump):
|
||||
def get_devices_info_from_dump(dump, driver):
|
||||
from os import linesep
|
||||
|
||||
dump_clean = list(
|
||||
filter(
|
||||
lambda s: "--device=rocm" in s or "gpu-arch-name:" in s,
|
||||
dump.split(linesep),
|
||||
|
||||
if driver == "hip":
|
||||
dump_clean = list(
|
||||
filter(
|
||||
lambda s: "AMD" in s,
|
||||
dump.split(linesep),
|
||||
)
|
||||
)
|
||||
else:
|
||||
dump_clean = list(
|
||||
filter(
|
||||
lambda s: f"--device={driver}" in s or "gpu-arch-name:" in s,
|
||||
dump.split(linesep),
|
||||
)
|
||||
)
|
||||
)
|
||||
arch_pairs = [
|
||||
(
|
||||
dump_clean[i].split("=")[1].strip(),
|
||||
@@ -87,16 +95,17 @@ def get_rocm_device_arch(device_num=0, extra_args=[]):
|
||||
return arch_pairs
|
||||
|
||||
dump_device_info = None
|
||||
driver = "hip" if hip_driver else "rocm"
|
||||
try:
|
||||
dump_device_info = run_cmd(
|
||||
"iree-run-module --dump_devices=rocm", raise_err=True
|
||||
"iree-run-module --dump_devices=" + driver, raise_err=True
|
||||
)
|
||||
except Exception as e:
|
||||
print("could not execute `iree-run-module --dump_devices=rocm`")
|
||||
print("could not execute `iree-run-module --dump_devices=" + driver + "`")
|
||||
|
||||
if dump_device_info is not None:
|
||||
device_num = 0 if device_num is None else device_num
|
||||
device_arch_pairs = get_devices_info_from_dump(dump_device_info[0])
|
||||
device_arch_pairs = get_devices_info_from_dump(dump_device_info[0], driver)
|
||||
if len(device_arch_pairs) > device_num: # can find arch in the list
|
||||
arch_in_device_dump = device_arch_pairs[device_num][1]
|
||||
|
||||
@@ -107,24 +116,22 @@ def get_rocm_device_arch(device_num=0, extra_args=[]):
|
||||
default_rocm_arch = "gfx1100"
|
||||
print(
|
||||
"Did not find ROCm architecture from `--iree-rocm-target-chip` flag"
|
||||
"\n or from `iree-run-module --dump_devices=rocm` command."
|
||||
"\n or from `iree-run-module --dump_devices` command."
|
||||
f"\nUsing {default_rocm_arch} as ROCm arch for compilation."
|
||||
)
|
||||
return default_rocm_arch
|
||||
|
||||
|
||||
# Get the default gpu args given the architecture.
|
||||
def get_iree_rocm_args(device_num=0, extra_args=[]):
|
||||
def get_iree_rocm_args(device_num=0, extra_args=[], hip_driver=False):
|
||||
ireert.flags.FUNCTION_INPUT_VALIDATION = False
|
||||
rocm_flags = ["--iree-rocm-link-bc=true"]
|
||||
|
||||
rocm_flags = []
|
||||
if check_rocm_device_arch_in_args(extra_args) is None:
|
||||
rocm_arch = get_rocm_device_arch(device_num, extra_args)
|
||||
rocm_arch = get_rocm_device_arch(device_num, extra_args, hip_driver=hip_driver)
|
||||
rocm_flags.append(f"--iree-rocm-target-chip={rocm_arch}")
|
||||
|
||||
return rocm_flags
|
||||
|
||||
|
||||
# Some constants taken from cuda.h
|
||||
CUDA_SUCCESS = 0
|
||||
CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT = 16
|
||||
|
||||
Reference in New Issue
Block a user