[GEMM] Add script to run one tuning config (#419)

The script runs one given config for debug purposes.
This commit is contained in:
Alexander Efimov
2023-12-08 03:12:03 +03:00
committed by GitHub
parent 1d6b919897
commit e19b5fd6bc
2 changed files with 101 additions and 0 deletions

View File

@@ -69,3 +69,26 @@ On some node, I saw the following runtime error
```
It's hard to reproduce the error. **Needs further investigation**
- https://github.com/ROCmSoftwarePlatform/frameworks-internal/issues/6011
# One config running script
`one_config.py` is a script that runs one given matmul config.
It is an interface to `tune_gemm.py` functionality and could be used for triton debugging.
### Usage
This script supports two methods to specify configuration parameters.
Variant 1: Separate command line attributes.
```bash
python one_config.py -m 256 -n 256 -k 256 --block_m 64 --block_n 64 --block_k 64 --group_m 1 --split_k 2 --num_warps 2 --num_stages 0 --waves_per_eu 0
```
Variant 2: one-line config description.
This is how configs are printed by `tune_gemm.py` script
```bash
python one_config.py --config_str M16_N8_K128_BM64_BN64_BK64_GM1_SK2_nW2_nS0_EU0
```

View File

@@ -0,0 +1,78 @@
"""
Script for running one Matrix Multiplication kernel config at a time
"""
import argparse
import re
import sys
import tune_gemm
def parse_args():
parser = argparse.ArgumentParser(
prog="check corectness of particular config for tuning gemm script",
allow_abbrev=False,
)
parser.add_argument("-m", type=int, default=0)
parser.add_argument("-n", type=int, default=0)
parser.add_argument("-k", type=int, default=0)
parser.add_argument("--block_m", type=int, default=0)
parser.add_argument("--block_n", type=int, default=0)
parser.add_argument("--block_k", type=int, default=0)
parser.add_argument("--group_m", type=int, default=0)
parser.add_argument("--split_k", type=int, default=0)
parser.add_argument("--num_warps", type=int, default=0)
parser.add_argument("--num_stages", type=int, default=0)
parser.add_argument("--waves_per_eu", type=int, default=0)
parser.add_argument("--config_str", type=str, default="", help="can take from gemm_tune.py script output, looks like M16_N8_K128_BM64_BN64_BK64_GM1_SK2_nW2_nS0_EU0")
args = parser.parse_args()
return args
def parse_config(cfg_str):
values = cfg_str.split("_")
config_name = {"M": "M",
"N": "N",
"K": "K",
"BM": "BLOCK_SIZE_M",
"BN": "BLOCK_SIZE_N",
"BK": "BLOCK_SIZE_K",
"GM": "GROUP_SIZE_M",
"SK": "SPLIT_K",
"nW": "num_warps",
"nS": "num_stages",
"EU": "waves_per_eu",
}
config = {}
for val in values:
match = re.search("([a-zA-Z]*)([0-9]*)", val)
if match:
cfg_field_name = config_name[match.group(1)]
config[cfg_field_name] = int(match.group(2))
return config
def main():
args = parse_args()
if args.config_str:
config = parse_config(args.config_str)
else:
config = {"M": args.m,
"N": args.n,
"K": args.k,
"BLOCK_SIZE_M": args.block_m,
"BLOCK_SIZE_N": args.block_n,
"BLOCK_SIZE_K": args.block_k,
"GROUP_SIZE_M": args.group_m,
"SPLIT_K": args.split_k,
"num_warps": args.num_warps,
"num_stages": args.num_stages,
"waves_per_eu": args.waves_per_eu,
}
tune_gemm.test_correctness(config["M"], config["N"], config["K"], config, verbose=True)
if __name__ == "__main__":
sys.exit(main())