Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git, last synced 2026-04-03 03:00:17 -04:00.
add vulkan-heap-block-size flag (#498)
This commit (e67bcffea7) was committed via GitHub; its parent commit is 005ded3c6f.
@@ -5,7 +5,7 @@ from diffusers import LMSDiscreteScheduler
|
||||
from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
from stable_args import args
|
||||
from utils import get_shark_model
|
||||
from utils import get_shark_model, set_iree_runtime_flags
|
||||
from opt_params import get_unet, get_vae, get_clip
|
||||
import time
|
||||
|
||||
@@ -46,6 +46,7 @@ if __name__ == "__main__":
|
||||
|
||||
# One prompt string per generated image; downstream latents/embeddings are
# sized off this count.
batch_size = len(prompt)

# Configure IREE runtime flags (e.g. the Vulkan large-heap block size) BEFORE
# any module is compiled/loaded, so the runtime picks them up.
set_iree_runtime_flags()
# Fetch the three compiled Stable Diffusion submodules.
unet = get_unet()
vae = get_vae()
clip = get_clip()
@@ -97,4 +97,10 @@ p.add_argument(
|
||||
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
|
||||
)
|
||||
|
||||
# Expose VMA's preferredLargeHeapBlockSize as a CLI knob. Kept as a string:
# consumers splice it verbatim into an IREE runtime flag.
p.add_argument(
    "--vulkan_large_heap_block_size",
    default="4294967296",
    help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
)

# Materialize the parsed namespace the rest of the project imports as `args`.
args = p.parse_args()
|
||||
|
||||
@@ -4,6 +4,7 @@ import torch
|
||||
from shark.shark_inference import SharkInference
|
||||
from stable_args import args
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.iree_utils.vulkan_utils import set_iree_vulkan_runtime_flags
|
||||
|
||||
|
||||
def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
@@ -16,6 +17,7 @@ def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
extended_name = "{}_{}".format(model_name, device)
|
||||
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
|
||||
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
|
||||
print(f"loading existing vmfb from: {vmfb_path}")
|
||||
shark_module.load_module(vmfb_path, extra_args=extra_args)
|
||||
else:
|
||||
if args.save_vmfb:
|
||||
@@ -61,3 +63,14 @@ def compile_through_fx(model, inputs, model_name, extra_args=[]):
|
||||
)
|
||||
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
    """Apply device-specific IREE runtime flags from the parsed CLI args.

    Currently only Vulkan devices need runtime configuration: the VMA
    ``preferredLargeHeapBlockSize`` value is forwarded from
    ``--vulkan_large_heap_block_size`` (default 4G). Call this once,
    before any module is compiled or loaded, so the runtime sees the
    flags. No-op for non-Vulkan devices.
    """
    # Guard clause: nothing to configure unless a vulkan device is selected.
    if "vulkan" not in args.device:
        return
    set_iree_vulkan_runtime_flags(
        flags=[
            f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
        ]
    )
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
from os import linesep
|
||||
from shark.iree_utils._common import run_cmd
|
||||
import iree.runtime as ireert
|
||||
|
||||
|
||||
def get_vulkan_device_name():
|
||||
@@ -68,3 +69,9 @@ def get_iree_vulkan_args(extra_args=[]):
|
||||
if vulkan_triple_flag is not None:
|
||||
vulkan_flag.append(vulkan_triple_flag)
|
||||
return vulkan_flag
|
||||
|
||||
|
||||
def set_iree_vulkan_runtime_flags(flags):
    """Register Vulkan-related flags with the IREE runtime's global flag set.

    Args:
        flags: iterable of strings in ``--name=value`` form; each one is
            handed to ``iree.runtime.flags.parse_flags``, which mutates
            process-global runtime state (so call before creating devices).
    """
    for flag in flags:
        ireert.flags.parse_flags(flag)
|
||||
Reference in New Issue
Block a user