set task_topology_max_group to cpu_count (#1594)

by default. Can be overriden with a flag of the same str
This commit is contained in:
Daniel Garvey
2023-06-26 16:54:06 -05:00
committed by GitHub
parent 74a7202173
commit 75672c0e28
3 changed files with 29 additions and 2 deletions

View File

@@ -14,6 +14,7 @@
import iree.runtime as ireert
import iree.compiler as ireec
from shark.iree_utils._common import iree_device_map, iree_target_map
from shark.iree_utils.cpu_utils import get_iree_cpu_rt_args
from shark.iree_utils.benchmark_utils import *
from shark.parser import shark_args
import numpy as np
@@ -352,6 +353,12 @@ def load_vmfb_using_mmap(
config = ireert.Config(device=haldevice)
else:
config = get_iree_runtime_config(device)
if "task" in device:
print(
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
)
for flag in get_iree_cpu_rt_args():
ireert.flags.parse_flags(flag)
# Now load vmfb.
# Two scenarios we have here :-
# 1. We either have the vmfb already saved and therefore pass the path of it.
@@ -359,7 +366,6 @@ def load_vmfb_using_mmap(
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
# (This would arise if we're invoking `compile` from a SharkInference obj)
temp_file_to_unlink = None
if isinstance(flatbuffer_blob_or_path, Path):
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
if (

View File

@@ -16,6 +16,7 @@
import subprocess
import platform
from shark.parser import shark_args
def get_cpu_count():
@@ -44,4 +45,18 @@ def get_iree_cpu_args():
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
raise Exception(error_message)
print(f"Target triple found:{target_triple}")
return [f"--iree-llvmcpu-target-triple={target_triple}"]
return [
f"--iree-llvmcpu-target-triple={target_triple}",
]
# Get iree runtime flags for cpu
def get_iree_cpu_rt_args():
default = get_cpu_count()
default = default if default <= 8 else default - 2
cpu_count = (
default
if shark_args.task_topology_max_group_count is None
else shark_args.task_topology_max_group_count
)
return [f"--task_topology_max_group_count={cpu_count}"]

View File

@@ -119,5 +119,11 @@ parser.add_argument(
"to augment the base device allocator",
choices=["debug", "caching"],
)
parser.add_argument(
"--task_topology_max_group_count",
type=str,
default=None,
help="passthrough flag for the iree flag of the same name. If None, defaults to cpu-count",
)
shark_args, unknown = parser.parse_known_args()