Update model annotation tool (#361)

Usage:
with create_context() as ctx:
  module = model_annotation(ctx, input_contents=..., config_path=..., search_op=...)

Example:
The example is to annotate the minilm model with GPU config files.
python model_annotation.py /nodclouddata/vivian/minilm_model/model.mlir /nodclouddata/vivian/minilm_model/model_config.json
This commit is contained in:
yzhang93
2022-09-23 15:44:51 -07:00
committed by GitHub
parent b9c8985047
commit 587d74b449

View File

@@ -12,22 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import json
import os
from typing import List, Dict
import sys
from typing import Dict, List
from iree.compiler import ir
from iree.compiler.transforms import ireec as ireec_trans
MATMUL_OP_NAMES = set(
["linalg.matmul", "linalg.batch_matmul", "mhlo.dot", "mhlo.dot_general"]
)
idx = 0
def model_annotation(
ctx: ir.Context, *, input_contents: str, config_path: str
ctx: ir.Context,
*,
input_contents: str,
config_path: str,
search_op: str = "matmul",
):
if os.path.isfile(input_contents):
with open(input_contents, "rb") as f:
@@ -41,21 +40,35 @@ def model_annotation(
# The Python API does not expose a general walk() function, so we just
# do it ourselves.
walk_children(module.operation, configs)
walk_children(module.operation, configs, 0, search_op)
if not module.operation.verify():
raise RuntimeError("Modified program does not verify!")
# More efficient than: print(module)
# - Disables verification (already done above)
# - Writes as binary, avoiding costly unicode conversions
sys.stdout.buffer.write(
module.operation.get_asm(assume_verified=True, binary=True)
)
return module
def walk_children(op: ir.Operation, configs: List[Dict]):
def walk_children(
op: ir.Operation, configs: List[Dict], idx: int, search_op: str
):
if search_op == "matmul":
op_names = ["linalg.matmul", "mhlo.dot"]
elif search_op == "bmm":
op_names = ["linalg.batch_matmul", "mhlo.dot_general"]
elif search_op == "conv":
op_names = ["mhlo.convolution", "linalg.conv_2d_nhwc_hwcf"]
elif search_op == "all":
op_names = [
"mhlo.dot",
"mhlo.dot_general",
"mhlo.convolution",
"linalg.matmul",
"linalg.batch_matmul",
"linalg.conv_2d_nhwc_hwcf",
]
else:
raise ValueError(f"{search_op} op is not tunable.")
for region in op.regions:
for block in region.blocks:
for child_op in block.operations:
@@ -63,30 +76,32 @@ def walk_children(op: ir.Operation, configs: List[Dict]):
# 'operation' and 'name' attributes.
if isinstance(child_op, ir.OpView):
child_op = child_op.operation
if child_op.name in MATMUL_OP_NAMES:
global idx
(
tile_sizes,
pipeline,
workgroup_size,
split_k,
pipeline_depth,
) = parse_config(configs[idx])
add_compilation_info(
child_op,
tile_sizes=tile_sizes,
pipeline=pipeline,
workgroup_size=workgroup_size,
pipeline_depth=pipeline_depth,
)
if split_k:
add_split_k(child_op, split_k)
if child_op.name in op_names and idx < len(configs):
add_attributes(child_op, configs[idx])
idx = idx + 1
print(f"Updated op {child_op}", file=sys.stderr)
walk_children(child_op, configs)
walk_children(child_op, configs, idx, search_op)
def add_attributes(op: ir.Operation, config: Dict):
(
tile_sizes,
pipeline,
workgroup_size,
split_k,
pipeline_depth,
) = parse_config(config)
add_compilation_info(
op,
tile_sizes=tile_sizes,
pipeline=pipeline,
workgroup_size=workgroup_size,
pipeline_depth=pipeline_depth,
)
if split_k:
add_attribute_by_name(op, "iree_flow_split_k", split_k)
def parse_config(config: Dict):
@@ -145,9 +160,9 @@ def add_compilation_info(
op.attributes["compilation_info"] = attr
def add_split_k(op: ir.Operation, k: int):
attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), k)
op.attributes["iree_flow_split_k"] = attr
def add_attribute_by_name(op: ir.Operation, name: str, val: int):
attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), val)
op.attributes[name] = attr
def create_context() -> ir.Context:
@@ -159,6 +174,14 @@ def create_context() -> ir.Context:
if __name__ == "__main__":
with create_context() as ctx:
model_annotation(
ctx, input_contents=sys.argv[1], config_path=sys.argv[2]
module = model_annotation(
ctx,
input_contents=sys.argv[1],
config_path=sys.argv[2],
search_op="all",
)
mlir_str = str(module)
filename = "tuned_model.mlir"
with open(filename, "w") as f:
f.write(mlir_str)
print(f"Saved mlir in {filename}.")