add vulkan-heap-block-size flag (#498)

This commit is contained in:
Phaneesh Barwaria
2022-11-22 13:30:25 +05:30
committed by GitHub
parent 005ded3c6f
commit e67bcffea7
4 changed files with 28 additions and 1 deletions

View File

@@ -5,7 +5,7 @@ from diffusers import LMSDiscreteScheduler
from tqdm.auto import tqdm
import numpy as np
from stable_args import args
from utils import get_shark_model
from utils import get_shark_model, set_iree_runtime_flags
from opt_params import get_unet, get_vae, get_clip
import time
@@ -46,6 +46,7 @@ if __name__ == "__main__":
batch_size = len(prompt)
set_iree_runtime_flags()
unet = get_unet()
vae = get_vae()
clip = get_clip()

View File

@@ -97,4 +97,10 @@ p.add_argument(
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
)
p.add_argument(
"--vulkan_large_heap_block_size",
default="4294967296",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
)
args = p.parse_args()

View File

@@ -4,6 +4,7 @@ import torch
from shark.shark_inference import SharkInference
from stable_args import args
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import set_iree_vulkan_runtime_flags
def _compile_module(shark_module, model_name, extra_args=[]):
@@ -16,6 +17,7 @@ def _compile_module(shark_module, model_name, extra_args=[]):
extended_name = "{}_{}".format(model_name, device)
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=extra_args)
else:
if args.save_vmfb:
@@ -61,3 +63,14 @@ def compile_through_fx(model, inputs, model_name, extra_args=[]):
)
return _compile_module(shark_module, model_name, extra_args)
def set_iree_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
]
if "vulkan" in args.device:
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
return

View File

@@ -16,6 +16,7 @@
from os import linesep
from shark.iree_utils._common import run_cmd
import iree.runtime as ireert
def get_vulkan_device_name():
@@ -68,3 +69,9 @@ def get_iree_vulkan_args(extra_args=[]):
if vulkan_triple_flag is not None:
vulkan_flag.append(vulkan_triple_flag)
return vulkan_flag
def set_iree_vulkan_runtime_flags(flags):
for flag in flags:
ireert.flags.parse_flags(flag)
return