Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git (synced 2026-04-03 03:00:17 -04:00)

Modify model annotation tool to walk through ops by shape (#692)
@@ -22,7 +22,7 @@ from shark.model_annotation import model_annotation
with create_context() as ctx:
    module = model_annotation(ctx, input_contents=..., config_path=..., search_op=...)
2. Run model_annotation.py directly
python model_annotation.py path_to_original_mlir path_to_config_file
python model_annotation.py -model path_to_original_mlir -config_path path_to_config_file
"""

import json
@@ -39,21 +39,18 @@ def model_annotation(
    *,
    input_contents: str,
    config_path: str,
    search_op: str = "matmul",
    search_op: str,
):
    if os.path.isfile(input_contents):
        with open(input_contents, "rb") as f:
            input_contents = f.read()

    module = ir.Module.parse(input_contents)

    with open(config_path, "r") as f:
        data = json.load(f)
    configs = data["options"]
    configs = load_model_configs(config_path)

    # The Python API does not expose a general walk() function, so we just
    # do it ourselves.
    walk_children(module.operation, configs, 0, search_op)
    walk_children(module.operation, configs, search_op)

    if not module.operation.verify():
        raise RuntimeError("Modified program does not verify!")
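For reference, a minimal sketch of calling the annotation pass through the Python API with the now-required search_op argument; the file names here are illustrative and simply mirror the script's defaults:

    with create_context() as ctx:
        module = model_annotation(
            ctx,
            input_contents="model.mlir",
            config_path="best_configs.json",
            search_op="all",
        )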
@@ -61,15 +58,49 @@ def model_annotation(
    return module


def walk_children(
    op: ir.Operation, configs: List[Dict], idx: int, search_op: str
):
def load_model_configs(config_path: str):
    config = {}
    with open(config_path, "r") as f:
        for line in f:
            data = json.loads(line)

            if "identifier" not in data.keys():
                continue
            if data["identifier"] == "matmul":
                matrix_size = [data["m"], data["n"], data["k"]]
            elif data["identifier"] == "bmm":
                matrix_size = [data["b"], data["m"], data["n"], data["k"]]
            elif data["identifier"] == "generic":
                matrix_size = [1, data["b"], data["m"], data["n"], data["k"]]
            elif data["identifier"] == "conv":
                matrix_size = [
                    data["n"],
                    data["ih"],
                    data["iw"],
                    data["c"],
                    data["kh"],
                    data["kw"],
                    data["f"],
                    data["oh"],
                    data["ow"],
                    data["d"],
                    data["s"],
                    data["p"],
                ]
            config[shape_list_to_string(matrix_size)] = data
    f.close()
    return config
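Since the config file is consumed line by line, each line must be a self-contained JSON object. As a purely illustrative sketch (values made up, and only option fields that this script actually reads are shown), a matmul entry such as

    {"identifier": "matmul", "m": 4096, "n": 4096, "k": 4096,
     "options": [{"pipeline": "...", "work_group_sizes": [...], "pipeline_depth": ...}]}

would be stored in the returned dict under the key "4096x4096x4096".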
def walk_children(op: ir.Operation, configs: List[Dict], search_op: str):
    if search_op == "matmul":
        op_names = ["linalg.matmul", "mhlo.dot"]
    elif search_op == "bmm":
        op_names = ["linalg.batch_matmul", "mhlo.dot_general"]
    elif search_op == "conv":
        op_names = ["mhlo.convolution", "linalg.conv_2d_nhwc_hwcf"]
    elif search_op == "generic":
        op_names = ["linalg.generic"]
    elif search_op == "all":
        op_names = [
            "mhlo.dot",
@@ -78,6 +109,7 @@ def walk_children(
            "linalg.matmul",
            "linalg.batch_matmul",
            "linalg.conv_2d_nhwc_hwcf",
            "linalg.generic",
        ]
    else:
        raise ValueError(f"{search_op} op is not tunable.")
@@ -89,37 +121,167 @@ def walk_children(
                # 'operation' and 'name' attributes.
                if isinstance(child_op, ir.OpView):
                    child_op = child_op.operation
                if child_op.name in op_names and idx < len(configs):
                    add_attributes(child_op, configs[idx])
                    idx = idx + 1
                if child_op.name in op_names:
                    if child_op.name == "linalg.generic":
                        # This is for generic op that has contractionOpInterface
                        # which is basically einsum("mk,bkn->bmn")
                        op_result = str(child_op.results[0])
                        op_iterator = str(
                            child_op.attributes["iterator_types"]
                        )
                        if len(child_op.operands) != 3:
                            continue
                        if "reduction" not in op_iterator:
                            continue
                        if (
                            "arith.addf" not in op_result
                            or "arith.mulf" not in op_result
                        ):
                            continue
                        if "arith.subf" in op_result:
                            continue

                    child_op_shape = get_op_shape(child_op, search_op)
                    if (
                        child_op_shape in configs.keys()
                        and configs[child_op_shape]["options"][0] != None
                    ):
                        add_attributes(
                            child_op, configs[child_op_shape]["options"][0]
                        )
                        print(f"Updated op {child_op}", file=sys.stderr)
                walk_children(child_op, configs, idx, search_op)

                walk_children(child_op, configs, search_op)
def add_attributes(op: ir.Operation, config: Dict):
    (
        tile_sizes,
        pipeline,
        workgroup_size,
        split_k,
        pipeline_depth,
    ) = parse_config(config)
def get_op_shape(op: ir.Operation, search_op: str):
    shape_list = []
    if search_op in ["generic", "all"]:
        if op.name in ["linalg.generic"]:
            input1 = str(op.operands[0].type)
            input2 = str(op.operands[1].type)
            m = input1.split("tensor<")[1].split("x")[0]
            b = input2.split("tensor<")[1].split("x")[0]
            k = input2.split("tensor<")[1].split("x")[1]
            n = input2.split("tensor<")[1].split("x")[2]
            shape_list = [1, int(b), int(m), int(n), int(k)]

    add_compilation_info(
        op,
        tile_sizes=tile_sizes,
        pipeline=pipeline,
        workgroup_size=workgroup_size,
        pipeline_depth=pipeline_depth,
    )
    if search_op in ["matmul", "all"]:
        if op.name in ["mhlo.dot"]:
            op_result = str(op.results[0])
            m = op_result.split("tensor<")[1].split("x")[0]
            k = op_result.split("tensor<")[1].split("x")[1]
            n = op_result.split("tensor<")[2].split("x")[1]
            shape_list = [int(m), int(n), int(k)]
        elif op.name in ["linalg.matmul"]:
            op_result = str(op.results[0]).split("ins(")[1]
            m = op_result.split("tensor<")[1].split("x")[0]
            k = op_result.split("tensor<")[1].split("x")[1]
            n = op_result.split("tensor<")[2].split("x")[1]
            shape_list = [int(m), int(n), int(k)]

    if split_k:
        add_attribute_by_name(op, "iree_flow_split_k", split_k)
    if search_op in ["bmm", "all"]:
        if op.name in ["mhlo.dot_general"]:
            op_result = str(op.results[0])
            b = op_result.split("tensor<")[1].split("x")[1]
            m = op_result.split("tensor<")[1].split("x")[2]
            k = op_result.split("tensor<")[1].split("x")[3]
            n = op_result.split("tensor<")[3].split("x")[3]
            shape_list = [int(b), int(m), int(n), int(k)]
        elif op.name in ["linalg.batch_matmul"]:
            op_result = str(op.results[0]).split("ins(")[1]
            b = op_result.split("tensor<")[1].split("x")[0]
            m = op_result.split("tensor<")[1].split("x")[1]
            k = op_result.split("tensor<")[1].split("x")[2]
            n = op_result.split("tensor<")[3].split("x")[2]
            shape_list = [int(b), int(m), int(n), int(k)]

    if search_op in ["conv", "all"]:
        if op.name in ["mhlo.convolution"]:
            op_result = str(op.results[0])
            dilation = (
                str(op.attributes["rhs_dilation"])
                .split("dense<")[1]
                .split(">")[0]
            )
            stride = (
                str(op.attributes["window_strides"])
                .split("dense<")[1]
                .split(">")[0]
            )
            pad = (
                str(op.attributes["padding"]).split("dense<")[1].split(">")[0]
            )
            n = op_result.split("tensor<")[1].split("x")[0]
            ih = op_result.split("tensor<")[1].split("x")[1]
            iw = op_result.split("tensor<")[1].split("x")[2]
            c = op_result.split("tensor<")[1].split("x")[3]
            kh = op_result.split("tensor<")[2].split("x")[0]
            kw = op_result.split("tensor<")[2].split("x")[1]
            f = op_result.split("tensor<")[2].split("x")[3]
            oh = op_result.split("tensor<")[3].split("x")[1]
            ow = op_result.split("tensor<")[3].split("x")[2]
            shape_list = [
                int(n),
                int(ih),
                int(iw),
                int(c),
                int(kh),
                int(kw),
                int(f),
                int(oh),
                int(ow),
                int(dilation),
                int(stride),
                int(pad),
            ]

        elif op.name in ["linalg.conv_2d_nhwc_hwcf"]:
            op_result = str(op.results[0]).split("ins(")[1]
            dilation = (
                str(op.attributes["dilations"])
                .split("dense<")[1]
                .split(">")[0]
            )
            stride = (
                str(op.attributes["strides"]).split("dense<")[1].split(">")[0]
            )
            pad = 0
            n = op_result.split("tensor<")[1].split("x")[0]
            ih = op_result.split("tensor<")[1].split("x")[1]
            iw = op_result.split("tensor<")[1].split("x")[2]
            c = op_result.split("tensor<")[1].split("x")[3]
            kh = op_result.split("tensor<")[2].split("x")[0]
            kw = op_result.split("tensor<")[2].split("x")[1]
            f = op_result.split("tensor<")[2].split("x")[3]
            oh = op_result.split("tensor<")[3].split("x")[1]
            ow = op_result.split("tensor<")[3].split("x")[2]
            shape_list = [
                int(n),
                int(ih),
                int(iw),
                int(c),
                int(kh),
                int(kw),
                int(f),
                int(oh),
                int(ow),
                int(dilation),
                int(stride),
                int(pad),
            ]

    shape_str = shape_list_to_string(shape_list)
    return shape_str
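A quick illustration of the string slicing in get_op_shape, using a made-up linalg.matmul (shapes are arbitrary):

    # str(op.results[0]).split("ins(")[1] for a matmul printed as
    #   ins(%0, %1 : tensor<16x64xf32>, tensor<64x32xf32>) ...
    # yields m = "16" and k = "64" from the first tensor<...> chunk and
    # n = "32" from the second, so get_op_shape returns "16x32x64" --
    # the same key format that load_model_configs builds for its dictionary.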
def parse_config(config: Dict):
def add_attributes(op: ir.Operation, config: List[Dict]):
    # Parse the config file
    split_k = None
    pipeline_depth = None
    store_stage = None
    subgroup_size = None

    if "GPU" in config["pipeline"]:
        pipeline = (
            "LLVMGPUMatmulSimt"
@@ -132,6 +294,10 @@ def parse_config(config: Dict):
            pipeline_depth = config["pipeline_depth"]
        if "split_k" in config.keys():
            split_k = config["split_k"]
        if "devices" in config.keys():
            devices = config["devices"]
        if "shard_sizes" in config.keys():
            shard_sizes = config["shard_sizes"]
    elif "SPIRV" in config["pipeline"]:
        pipeline = config["pipeline"]
        tile_sizes = [
@@ -139,11 +305,17 @@ def parse_config(config: Dict):
            config["parallel_tile_sizes"],
            config["reduction_tile_sizes"],
        ]
        workgroup_size = config["work_group_sizes"]
        if "vector_tile_sizes" in config.keys():
            tile_sizes += [config["vector_tile_sizes"]]
        if "window_tile_sizes" in config.keys():
            tile_sizes += [config["window_tile_sizes"]]
        workgroup_size = config["work_group_sizes"]
        if "subgroup_size" in config.keys():
            subgroup_size = config["subgroup_size"]
        if "pipeline_depth" in config.keys():
            pipeline_depth = config["pipeline_depth"]
        if "store_stage" in config.keys():
            store_stage = config["store_stage"]
    else:
        # For IREE CPU pipelines
        pipeline = config["pipeline"]
@@ -153,40 +325,45 @@ def parse_config(config: Dict):
            config["reduction_tile_sizes"],
        ]
        workgroup_size = []
    return tile_sizes, pipeline, workgroup_size, split_k, pipeline_depth


def add_compilation_info(
    op: ir.Operation,
    tile_sizes: List[List[int]],
    pipeline: str,
    workgroup_size: List[int],
    pipeline_depth: int,
):
    # We don't have a Python binding for CompilationInfo, so we just parse
    # its string form.
    if pipeline_depth:
        attr = ir.Attribute.parse(
            f"#iree_codegen.compilation_info<"
            f"lowering_config = <tile_sizes = {repr(tile_sizes)}>, "
            f"translation_info = <{pipeline} pipeline_depth = {pipeline_depth}>, "
            f"workgroup_size = {repr(workgroup_size)}>"
        )
    # Add compilation info as an attribute. We don't have a Python binding for CompilationInfo,
    # so we just parse its string form.
    if pipeline_depth != None:
        translation_info = f"{pipeline} pipeline_depth = {pipeline_depth}"
        if store_stage != None:
            translation_info += f" store_stage = {store_stage}"
    else:
        attr = ir.Attribute.parse(
            f"#iree_codegen.compilation_info<"
            f"lowering_config = <tile_sizes = {repr(tile_sizes)}>, "
            f"translation_info = <{pipeline}>, "
            f"workgroup_size = {repr(workgroup_size)}>"
        )
        translation_info = f"{pipeline}"

    compilation_info = (
        f"#iree_codegen.compilation_info<"
        f"lowering_config = <tile_sizes = {repr(tile_sizes)}>, "
        f"translation_info = <{translation_info}>, "
        f"workgroup_size = {repr(workgroup_size)} "
    )

    if subgroup_size != None:
        compilation_info += f", subgroup_size = {subgroup_size}>"
    else:
        compilation_info += ">"

    attr = ir.Attribute.parse(compilation_info)
    op.attributes["compilation_info"] = attr

    # Add other attributes if required.
    if split_k:
        add_attribute_by_name(op, "iree_flow_split_k", split_k)
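Pieced together from the f-strings above, the attribute handed to ir.Attribute.parse looks roughly like the following (tile sizes, pipeline name, and workgroup size are placeholders for whatever the config supplied):

    #iree_codegen.compilation_info<lowering_config = <tile_sizes = [[32, 128, 64]]>, translation_info = <LLVMGPUMatmulSimt pipeline_depth = 4>, workgroup_size = [32, 8, 1]>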
def add_attribute_by_name(op: ir.Operation, name: str, val: int):
    attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), val)
    op.attributes[name] = attr


def shape_list_to_string(input):
    return "x".join([str(d) for d in input])
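A one-line illustration of the helper above:

    shape_list_to_string([16, 32, 64])  # -> "16x32x64", the key format shared by load_model_configs and get_op_shape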
def create_context() -> ir.Context:
    context = ir.Context()
    ireec_trans.register_all_dialects(context)
@@ -195,15 +372,48 @@ def create_context() -> ir.Context:


if __name__ == "__main__":
    import argparse
    from pathlib import Path

    def path_expand(s):
        return Path(s).expanduser().resolve()

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-model",
        type=path_expand,
        default="model.mlir",
        help="Path to the input mlir file",
    )
    parser.add_argument(
        "-config_path",
        type=path_expand,
        default="best_configs.json",
        help="Path where stores the op config file",
    )
    parser.add_argument(
        "-output_path",
        type=path_expand,
        default="tuned_model.mlir",
        help="Path to save the annotated mlir file",
    )
    parser.add_argument(
        "-search_op",
        type=str,
        default="all",
        help="Op to be optimized. options are matmul, bmm, conv.",
    )

    args = parser.parse_args()

    with create_context() as ctx:
        module = model_annotation(
            ctx,
            input_contents=sys.argv[1],
            config_path=sys.argv[2],
            search_op="all",
            input_contents=args.model,
            config_path=args.config_path,
            search_op=args.search_op,
        )
        mlir_str = str(module)
        filename = "tuned_model.mlir"
        with open(filename, "w") as f:
        with open(args.output_path, "w") as f:
            f.write(mlir_str)
        print(f"Saved mlir in {filename}.")
        print(f"Saved mlir in {args.output_path}.")
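With the new flags in place, a typical invocation (paths illustrative, mirroring the defaults above) looks like:

    python model_annotation.py -model model.mlir -config_path best_configs.json -output_path tuned_model.mlir -search_op matmul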