mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
set task_topology_max_group to cpu_count (#1594)
by default. Can be overriden with a flag of the same str
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
import iree.runtime as ireert
|
||||
import iree.compiler as ireec
|
||||
from shark.iree_utils._common import iree_device_map, iree_target_map
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_rt_args
|
||||
from shark.iree_utils.benchmark_utils import *
|
||||
from shark.parser import shark_args
|
||||
import numpy as np
|
||||
@@ -352,6 +353,12 @@ def load_vmfb_using_mmap(
|
||||
config = ireert.Config(device=haldevice)
|
||||
else:
|
||||
config = get_iree_runtime_config(device)
|
||||
if "task" in device:
|
||||
print(
|
||||
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
|
||||
)
|
||||
for flag in get_iree_cpu_rt_args():
|
||||
ireert.flags.parse_flags(flag)
|
||||
# Now load vmfb.
|
||||
# Two scenarios we have here :-
|
||||
# 1. We either have the vmfb already saved and therefore pass the path of it.
|
||||
@@ -359,7 +366,6 @@ def load_vmfb_using_mmap(
|
||||
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
|
||||
# (This would arise if we're invoking `compile` from a SharkInference obj)
|
||||
temp_file_to_unlink = None
|
||||
|
||||
if isinstance(flatbuffer_blob_or_path, Path):
|
||||
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
|
||||
if (
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import subprocess
|
||||
import platform
|
||||
from shark.parser import shark_args
|
||||
|
||||
|
||||
def get_cpu_count():
|
||||
@@ -44,4 +45,18 @@ def get_iree_cpu_args():
|
||||
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
|
||||
raise Exception(error_message)
|
||||
print(f"Target triple found:{target_triple}")
|
||||
return [f"--iree-llvmcpu-target-triple={target_triple}"]
|
||||
return [
|
||||
f"--iree-llvmcpu-target-triple={target_triple}",
|
||||
]
|
||||
|
||||
|
||||
# Get iree runtime flags for cpu
|
||||
def get_iree_cpu_rt_args():
|
||||
default = get_cpu_count()
|
||||
default = default if default <= 8 else default - 2
|
||||
cpu_count = (
|
||||
default
|
||||
if shark_args.task_topology_max_group_count is None
|
||||
else shark_args.task_topology_max_group_count
|
||||
)
|
||||
return [f"--task_topology_max_group_count={cpu_count}"]
|
||||
|
||||
@@ -119,5 +119,11 @@ parser.add_argument(
|
||||
"to augment the base device allocator",
|
||||
choices=["debug", "caching"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_topology_max_group_count",
|
||||
type=str,
|
||||
default=None,
|
||||
help="passthrough flag for the iree flag of the same name. If None, defaults to cpu-count",
|
||||
)
|
||||
|
||||
shark_args, unknown = parser.parse_known_args()
|
||||
|
||||
Reference in New Issue
Block a user