mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-08 21:38:04 -05:00
fix formatting and disable explicit vulkan env settings.
This commit is contained in:
@@ -84,7 +84,7 @@ def dumpstacks():
|
||||
if line:
|
||||
code.append(" " + line.strip())
|
||||
with open(os.path.join(shark_tmp, "stack_dump.log"), "w") as f:
|
||||
f.write("\n".join(code))
|
||||
f.write("\n".join(code))
|
||||
|
||||
|
||||
def setup_middleware(app):
|
||||
|
||||
@@ -326,6 +326,7 @@ class LanguageModel:
|
||||
self.global_iter += 1
|
||||
return result_output, total_time
|
||||
|
||||
|
||||
def llm_chat_api(InputData: dict):
|
||||
from datetime import datetime as dt
|
||||
|
||||
@@ -392,7 +393,6 @@ def llm_chat_api(InputData: dict):
|
||||
print("prompt = ", prompt)
|
||||
|
||||
for res_op, _ in llm_model.chat(prompt):
|
||||
|
||||
if is_chat_completion_api:
|
||||
choices = [
|
||||
{
|
||||
@@ -421,6 +421,7 @@ def llm_chat_api(InputData: dict):
|
||||
"choices": choices,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
lm = LanguageModel(
|
||||
"Trelis/Llama-2-7b-chat-hf-function-calling-v2",
|
||||
|
||||
@@ -83,9 +83,7 @@ class StableDiffusion(SharkPipelineBase):
|
||||
"clip": {"hf_model_name": base_model_id},
|
||||
"unet": {
|
||||
"hf_model_name": base_model_id,
|
||||
"unet_model": unet.UnetModel(
|
||||
hf_model_name=base_model_id
|
||||
),
|
||||
"unet_model": unet.UnetModel(hf_model_name=base_model_id),
|
||||
"batch_size": batch_size,
|
||||
# "is_controlled": is_controlled,
|
||||
# "num_loras": num_loras,
|
||||
|
||||
@@ -77,6 +77,7 @@ def get_available_devices():
|
||||
available_devices.extend(cpu_device)
|
||||
return available_devices
|
||||
|
||||
|
||||
def set_init_device_flags():
|
||||
if "vulkan" in cmd_opts.device:
|
||||
# set runtime flags for vulkan.
|
||||
@@ -109,6 +110,7 @@ def set_init_device_flags():
|
||||
elif "cpu" in cmd_opts.device:
|
||||
cmd_opts.device = "cpu"
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
# TODO: This function should be device-agnostic and piped properly
|
||||
# to general runtime driver init.
|
||||
@@ -177,6 +179,7 @@ def get_device_mapping(driver, key_combination=3):
|
||||
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
|
||||
return device_map
|
||||
|
||||
|
||||
def get_opt_flags(model, precision="fp16"):
|
||||
iree_flags = []
|
||||
if len(cmd_opts.iree_vulkan_target_triple) > 0:
|
||||
@@ -202,6 +205,7 @@ def get_opt_flags(model, precision="fp16"):
|
||||
iree_flags += ["--iree-flow-collapse-reduction-dims"]
|
||||
return iree_flags
|
||||
|
||||
|
||||
def map_device_to_name_path(device, key_combination=3):
|
||||
"""Gives the appropriate device data (supported name/path) for user
|
||||
selected execution device
|
||||
|
||||
@@ -90,6 +90,8 @@ class SharkPipelineBase:
|
||||
)
|
||||
|
||||
weights_path = self.get_io_params(submodel)
|
||||
if weights_path:
|
||||
ireec_flags.append("--iree-opt-const-eval=False")
|
||||
|
||||
self.iree_module_dict[submodel] = get_iree_compiled_module(
|
||||
self.tempfiles[submodel],
|
||||
|
||||
@@ -56,9 +56,7 @@ datas += [
|
||||
# hidden imports for pyinstaller
|
||||
hiddenimports = ["shark", "apps"]
|
||||
hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x]
|
||||
hiddenimports += [
|
||||
x for x in collect_submodules("diffusers") if "tests" not in x
|
||||
]
|
||||
hiddenimports += [x for x in collect_submodules("diffusers") if "tests" not in x]
|
||||
blacklist = ["tests", "convert"]
|
||||
hiddenimports += [
|
||||
x
|
||||
@@ -67,4 +65,4 @@ hiddenimports += [
|
||||
]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "test" not in x]
|
||||
hiddenimports += ["iree._runtime"]
|
||||
hiddenimports += [x for x in collect_submodules("scipy") if "test" not in x]
|
||||
hiddenimports += [x for x in collect_submodules("scipy") if "test" not in x]
|
||||
|
||||
@@ -38,7 +38,6 @@ def llm_chat_test(verbose=False):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# "Exercises the chatbot REST API of Shark. Make sure "
|
||||
# "Shark is running in API mode on 127.0.0.1:8080 before running"
|
||||
# "this script."
|
||||
|
||||
@@ -26,7 +26,6 @@ from apps.shark_studio.api.llm import llm_chat_api
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("http://") or encoding.startswith("https://"):
|
||||
|
||||
headers = {}
|
||||
response = requests.get(encoding, timeout=30, headers=headers)
|
||||
try:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from multiprocessing import Process, freeze_support
|
||||
|
||||
freeze_support()
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@@ -632,7 +632,9 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
with gr.Tab(label="Config", id=102) as sd_tab_config:
|
||||
with gr.Column(elem_classes=["sd-right-panel"]):
|
||||
with gr.Row(elem_classes=["fill"]):
|
||||
Path(get_configs_path()).mkdir(parents=True, exist_ok=True)
|
||||
Path(get_configs_path()).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
default_config_file = os.path.join(
|
||||
get_configs_path(),
|
||||
"default_sd_config.json",
|
||||
|
||||
@@ -40,10 +40,12 @@ default_sd_config = r"""{
|
||||
"embeddings": {}
|
||||
}"""
|
||||
|
||||
|
||||
def write_default_sd_config(path):
|
||||
with open(path, "w") as f:
|
||||
f.write(default_sd_config)
|
||||
|
||||
|
||||
def safe_name(name):
|
||||
return name.replace("/", "_").replace("-", "_")
|
||||
|
||||
|
||||
@@ -113,8 +113,8 @@ def get_iree_frontend_args(frontend):
|
||||
# Common args to be used given any frontend or device.
|
||||
def get_iree_common_args(debug=False):
|
||||
common_args = [
|
||||
"--iree-vm-bytecode-module-strip-source-map=true",
|
||||
"--iree-util-zero-fill-elided-attrs",
|
||||
"--mlir-elide-elementsattrs-if-larger=10",
|
||||
]
|
||||
if debug == True:
|
||||
common_args.extend(
|
||||
|
||||
@@ -33,7 +33,7 @@ def get_vulkan_target_env(vulkan_target_triple):
|
||||
device_type = get_device_type(triple)
|
||||
# get capabilities
|
||||
capabilities = get_vulkan_target_capabilities(triple)
|
||||
target_env = f"#vk.target_env<{version}, r({revision}), {extensions}, {vendor}:{device_type}, #vk.caps< {capabilities} >>"
|
||||
target_env = f"<#spirv.vce<{version}, r({revision}), {extensions}>, {vendor}:{device_type}, #spirv.resource_limits< {capabilities} >>"
|
||||
return target_env
|
||||
|
||||
|
||||
@@ -63,62 +63,62 @@ def get_extensions(triple):
|
||||
arch, product, os = triple
|
||||
if arch == "m1":
|
||||
ext = [
|
||||
"VK_KHR_16bit_storage",
|
||||
"VK_KHR_8bit_storage",
|
||||
"VK_KHR_shader_float16_int8",
|
||||
"VK_KHR_storage_buffer_storage_class",
|
||||
"VK_KHR_variable_pointers",
|
||||
"SPV_KHR_16bit_storage",
|
||||
"SPV_KHR_8bit_storage",
|
||||
"SPV_KHR_shader_float16_int8",
|
||||
"SPV_KHR_storage_buffer_storage_class",
|
||||
"SPV_KHR_variable_pointers",
|
||||
]
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
if arch == "valhall":
|
||||
ext = [
|
||||
"VK_KHR_16bit_storage",
|
||||
"VK_KHR_8bit_storage",
|
||||
"VK_KHR_shader_float16_int8",
|
||||
"VK_KHR_spirv_1_4",
|
||||
"VK_KHR_storage_buffer_storage_class",
|
||||
"VK_KHR_variable_pointers",
|
||||
"SPV_KHR_16bit_storage",
|
||||
"SPV_KHR_8bit_storage",
|
||||
"SPV_KHR_shader_float16_int8",
|
||||
"SPV_KHR_spirv_1_4",
|
||||
"SPV_KHR_storage_buffer_storage_class",
|
||||
"SPV_KHR_variable_pointers",
|
||||
]
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
if arch == "adreno":
|
||||
ext = [
|
||||
"VK_KHR_16bit_storage",
|
||||
"VK_KHR_shader_float16_int8",
|
||||
"VK_KHR_spirv_1_4",
|
||||
"VK_KHR_storage_buffer_storage_class",
|
||||
"VK_KHR_variable_pointers",
|
||||
"SPV_KHR_16bit_storage",
|
||||
"SPV_KHR_shader_float16_int8",
|
||||
"SPV_KHR_spirv_1_4",
|
||||
"SPV_KHR_storage_buffer_storage_class",
|
||||
"SPV_KHR_variable_pointers",
|
||||
]
|
||||
if os == "android31":
|
||||
ext.append("VK_KHR_8bit_storage")
|
||||
ext.append("SPV_KHR_8bit_storage")
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
if get_vendor(triple) == "SwiftShader":
|
||||
ext = ["VK_KHR_storage_buffer_storage_class"]
|
||||
ext = ["SPV_KHR_storage_buffer_storage_class"]
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
if arch == "unknown":
|
||||
ext = [
|
||||
"VK_KHR_storage_buffer_storage_class",
|
||||
"VK_KHR_variable_pointers",
|
||||
"SPV_KHR_storage_buffer_storage_class",
|
||||
"SPV_KHR_variable_pointers",
|
||||
]
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
ext = [
|
||||
"VK_KHR_16bit_storage",
|
||||
"VK_KHR_8bit_storage",
|
||||
"VK_KHR_shader_float16_int8",
|
||||
"VK_KHR_spirv_1_4",
|
||||
"VK_KHR_storage_buffer_storage_class",
|
||||
"VK_KHR_variable_pointers",
|
||||
"SPV_KHR_16bit_storage",
|
||||
"SPV_KHR_8bit_storage",
|
||||
"SPV_KHR_shader_float16_int8",
|
||||
"SPV_KHR_spirv_1_4",
|
||||
"SPV_KHR_storage_buffer_storage_class",
|
||||
"SPV_KHR_variable_pointers",
|
||||
"VK_EXT_subgroup_size_control",
|
||||
]
|
||||
|
||||
if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
|
||||
ext.append("VK_KHR_cooperative_matrix")
|
||||
ext.append("SPV_KHR_cooperative_matrix")
|
||||
if get_vendor(triple) == ["NVIDIA", "AMD", "Intel"]:
|
||||
ext.append("VK_KHR_shader_integer_dot_product")
|
||||
ext.append("SPV_KHR_shader_integer_dot_product")
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
|
||||
@@ -186,13 +186,13 @@ def get_vulkan_target_capabilities(triple):
|
||||
"Quad": 128,
|
||||
"PartitionedNV": 256,
|
||||
}
|
||||
cap["maxComputeSharedMemorySize"] = 16384
|
||||
cap["maxComputeWorkGroupInvocations"] = 128
|
||||
cap["maxComputeWorkGroupSize"] = [128, 128, 64]
|
||||
cap["subgroupSize"] = 32
|
||||
cap["max_compute_shared_memory_size"] = 16384
|
||||
cap["max_compute_workgroup_invocations"] = 128
|
||||
cap["max_compute_workgroup_size"] = [128, 128, 64]
|
||||
cap["subgroup_size"] = 32
|
||||
cap["subgroupFeatures"] = ["Basic"]
|
||||
cap["minSubgroupSize"] = None
|
||||
cap["maxSubgroupSize"] = None
|
||||
cap["min_subgroup_size"] = None
|
||||
cap["max_subgroup_size"] = None
|
||||
cap["shaderFloat16"] = False
|
||||
cap["shaderFloat64"] = False
|
||||
cap["shaderInt8"] = False
|
||||
@@ -209,13 +209,13 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["coopmatCases"] = None
|
||||
|
||||
if arch in ["rdna1", "rdna2", "rdna3"]:
|
||||
cap["maxComputeSharedMemorySize"] = 65536
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
|
||||
cap["max_compute_shared_memory_size"] = 65536
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroupSize"] = 64
|
||||
cap["minSubgroupSize"] = 32
|
||||
cap["maxSubgroupSize"] = 64
|
||||
cap["subgroup_size"] = 64
|
||||
cap["min_subgroup_size"] = 32
|
||||
cap["max_subgroup_size"] = 64
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
@@ -244,7 +244,8 @@ def get_vulkan_target_capabilities(triple):
|
||||
if arch == "rdna3":
|
||||
# TODO: Get scope value
|
||||
cap["coopmatCases"] = [
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>"
|
||||
"m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>",
|
||||
"m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>"
|
||||
]
|
||||
|
||||
if product == "rx5700xt":
|
||||
@@ -252,11 +253,11 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["storagePushConstant8"] = False
|
||||
|
||||
elif arch in ["rgcn5", "rgcn4", "rgcn3"]:
|
||||
cap["maxComputeSharedMemorySize"] = 65536
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
|
||||
cap["max_compute_shared_memory_size"] = 65536
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroupSize"] = 64
|
||||
cap["subgroup_size"] = 64
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
@@ -267,8 +268,8 @@ def get_vulkan_target_capabilities(triple):
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
cap["minSubgroupSize"] = 64
|
||||
cap["maxSubgroupSize"] = 64
|
||||
cap["min_subgroup_size"] = 64
|
||||
cap["max_subgroup_size"] = 64
|
||||
|
||||
if arch == "rgcn5":
|
||||
cap["shaderFloat16"] = True
|
||||
@@ -290,11 +291,11 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "m1":
|
||||
cap["maxComputeSharedMemorySize"] = 32768
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
|
||||
cap["max_compute_shared_memory_size"] = 32768
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroupSize"] = 32
|
||||
cap["subgroup_size"] = 32
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
@@ -321,11 +322,11 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "valhall":
|
||||
cap["maxComputeSharedMemorySize"] = 32768
|
||||
cap["maxComputeWorkGroupInvocations"] = 512
|
||||
cap["maxComputeWorkGroupSize"] = [512, 512, 512]
|
||||
cap["max_compute_shared_memory_size"] = 32768
|
||||
cap["max_compute_workgroup_invocations"] = 512
|
||||
cap["max_compute_workgroup_size"] = [512, 512, 512]
|
||||
|
||||
cap["subgroupSize"] = 16
|
||||
cap["subgroup_size"] = 16
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
@@ -352,11 +353,11 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "arc":
|
||||
cap["maxComputeSharedMemorySize"] = 32768
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 64]
|
||||
cap["max_compute_shared_memory_size"] = 32768
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 64]
|
||||
|
||||
cap["subgroupSize"] = 32
|
||||
cap["subgroup_size"] = 32
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
@@ -385,8 +386,8 @@ def get_vulkan_target_capabilities(triple):
|
||||
|
||||
elif arch == "cpu":
|
||||
if product == "swiftshader":
|
||||
cap["maxComputeSharedMemorySize"] = 16384
|
||||
cap["subgroupSize"] = 4
|
||||
cap["max_compute_shared_memory_size"] = 16384
|
||||
cap["subgroup_size"] = 4
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
@@ -397,13 +398,13 @@ def get_vulkan_target_capabilities(triple):
|
||||
]
|
||||
|
||||
elif arch in ["pascal"]:
|
||||
cap["maxComputeSharedMemorySize"] = 49152
|
||||
cap["maxComputeWorkGroupInvocations"] = 1536
|
||||
cap["maxComputeWorkGroupSize"] = [1536, 1024, 64]
|
||||
cap["max_compute_shared_memory_size"] = 49152
|
||||
cap["max_compute_workgroup_invocations"] = 1536
|
||||
cap["max_compute_workgroup_size"] = [1536, 1024, 64]
|
||||
|
||||
cap["subgroupSize"] = 32
|
||||
cap["minSubgroupSize"] = 32
|
||||
cap["maxSubgroupSize"] = 32
|
||||
cap["subgroup_size"] = 32
|
||||
cap["min_subgroup_size"] = 32
|
||||
cap["max_subgroup_size"] = 32
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
@@ -431,13 +432,13 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch in ["ampere", "turing"]:
|
||||
cap["maxComputeSharedMemorySize"] = 49152
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
|
||||
cap["max_compute_shared_memory_size"] = 49152
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroupSize"] = 32
|
||||
cap["minSubgroupSize"] = 32
|
||||
cap["maxSubgroupSize"] = 32
|
||||
cap["subgroup_size"] = 32
|
||||
cap["min_subgroup_size"] = 32
|
||||
cap["max_subgroup_size"] = 32
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
@@ -471,11 +472,11 @@ def get_vulkan_target_capabilities(triple):
|
||||
]
|
||||
|
||||
elif arch == "adreno":
|
||||
cap["maxComputeSharedMemorySize"] = 32768
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 64]
|
||||
cap["max_compute_shared_memory_size"] = 32768
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 64]
|
||||
|
||||
cap["subgroupSize"] = 64
|
||||
cap["subgroup_size"] = 64
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
@@ -491,14 +492,14 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["shaderInt16"] = True
|
||||
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
if os == "andorid31":
|
||||
if os == "android31":
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "unknown":
|
||||
cap["subgroupSize"] = 64
|
||||
cap["subgroup_size"] = 64
|
||||
cap["variablePointers"] = False
|
||||
cap["variablePointersStorageBuffer"] = False
|
||||
else:
|
||||
@@ -521,14 +522,14 @@ def get_vulkan_target_capabilities(triple):
|
||||
res += f"{k} = {'unit' if v == True else None}, "
|
||||
elif isinstance(v, list):
|
||||
if k == "subgroupFeatures":
|
||||
res += f"subgroupFeatures = {get_subgroup_val(v)}: i32, "
|
||||
elif k == "maxComputeWorkGroupSize":
|
||||
res += f"maxComputeWorkGroupSize = dense<{get_comma_sep_str(v)}>: vector<{len(v)}xi32>, "
|
||||
res += f"subgroup_features = {get_subgroup_val(v)}: i32, "
|
||||
elif k == "max_compute_workgroup_size":
|
||||
res += f"max_compute_workgroup_size = dense<{get_comma_sep_str(v)}>: vector<{len(v)}xi32>, "
|
||||
elif k == "coopmatCases":
|
||||
cmc = ""
|
||||
for case in v:
|
||||
cmc += f"#vk.coop_matrix_props<{case}>, "
|
||||
res += f"cooperativeMatrixPropertiesKHR = [{cmc[:-2]}], "
|
||||
cmc += f"#spirv.coop_matrix_props_khr<{case}>, "
|
||||
res += f"cooperative_matrix_properties_khr = [{cmc[:-2]}], "
|
||||
else:
|
||||
res += f"{k} = {get_comma_sep_str(v)}, "
|
||||
else:
|
||||
|
||||
@@ -144,6 +144,8 @@ def get_vulkan_target_triple(device_name):
|
||||
# Intel Targets
|
||||
elif any(x in device_name for x in ("A770", "A750")):
|
||||
triple = f"arc-770-{system_os}"
|
||||
elif "v620" in device_name:
|
||||
triple = f"rdna2-v620-{system_os}"
|
||||
|
||||
# Adreno Targets
|
||||
elif all(x in device_name for x in ("Adreno", "740")):
|
||||
@@ -169,7 +171,7 @@ def get_vulkan_triple_flag(device_name="", device_num=0, extra_args=[]):
|
||||
print(
|
||||
f"Found vulkan device {vulkan_device}. Using target triple {triple}"
|
||||
)
|
||||
return f"-iree-vulkan-target-triple={triple}"
|
||||
return f"--iree-vulkan-target-triple={triple}"
|
||||
print(
|
||||
"""Optimized kernel for your target device is not added yet.
|
||||
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
|
||||
@@ -184,7 +186,8 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
|
||||
|
||||
res_vulkan_flag = []
|
||||
res_vulkan_flag += [
|
||||
"--iree-stream-resource-max-allocation-size=3221225472"
|
||||
"--iree-stream-resource-max-allocation-size=3221225472",
|
||||
"--iree-flow-inline-constants-max-byte-length=0"
|
||||
]
|
||||
vulkan_triple_flag = None
|
||||
for arg in extra_args:
|
||||
@@ -197,6 +200,7 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
|
||||
vulkan_triple_flag = get_vulkan_triple_flag(
|
||||
device_num=device_num, extra_args=extra_args
|
||||
)
|
||||
res_vulkan_flag += [vulkan_triple_flag]
|
||||
|
||||
if vulkan_triple_flag is not None:
|
||||
vulkan_target_env = get_vulkan_target_env_flag(vulkan_triple_flag)
|
||||
|
||||
Reference in New Issue
Block a user