mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[GEMM] [Tuning] Add waves_per_eu to gemm tuning (#362)
* Add waves_per_eu in the tuning space * Do not allocate tensor on device during kernel compilation step * Add breakdown elapsed time * Parallelize the post-processing step * Parallelize the profile step with --ngpus * Better timing info printout
This commit is contained in:
@@ -12,6 +12,7 @@ import triton.language as tl
|
||||
from matmul_kernel import matmul_kernel
|
||||
|
||||
from datetime import datetime
|
||||
import multiprocessing
|
||||
|
||||
|
||||
def get_full_tuning_space():
|
||||
@@ -26,6 +27,7 @@ def get_full_tuning_space():
|
||||
# But keep this explicit so that we do not forget we may need to set it to
|
||||
# other values in the future
|
||||
num_stage_range = [1, 0]
|
||||
waves_per_eu_range = [0,1,2,3,4]
|
||||
|
||||
for block_m in block_mn_range:
|
||||
for block_n in block_mn_range:
|
||||
@@ -34,7 +36,8 @@ def get_full_tuning_space():
|
||||
for group_m in group_m_range:
|
||||
for split_k in split_k_range:
|
||||
for num_stages in num_stage_range:
|
||||
configs.append({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k, 'num_warps': num_warps, 'num_stages': num_stages})
|
||||
for waves_per_eu in waves_per_eu_range:
|
||||
configs.append({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k, 'num_warps': num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu})
|
||||
|
||||
return configs
|
||||
|
||||
@@ -100,55 +103,76 @@ def read_config(config):
|
||||
split_k = config.get('SPLIT_K')
|
||||
num_warps = config.get('num_warps')
|
||||
num_stages = config.get('num_stages')
|
||||
return block_m, block_n, block_k, group_m, split_k, num_warps, num_stages
|
||||
waves_per_eu = config.get('waves_per_eu')
|
||||
return block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu
|
||||
|
||||
|
||||
def gen_kernel_and_configStr_from_config(M, N, K, config):
|
||||
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages = read_config(config)
|
||||
configStr = f"M{M}_N{N}_K{K}_BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}"
|
||||
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu = read_config(config)
|
||||
configStr = f"M{M}_N{N}_K{K}_BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}"
|
||||
|
||||
matmul_def_str = f"""
|
||||
def matmul_{configStr}(a, b, c):
|
||||
M, K = a.shape
|
||||
K, N = b.shape
|
||||
def matmul_{configStr}(a, b, c, M, N, K, am, ak, bk, bn, cm, cn, warmup=False):
|
||||
#M, K = a.shape
|
||||
#K, N = b.shape
|
||||
grid = triton.cdiv(M, {block_m}) * triton.cdiv(N, {block_n}), {split_k}
|
||||
print(f'config: matmul_kernel_{configStr}')
|
||||
matmul_kernel_{configStr}[grid](
|
||||
a, b, c,
|
||||
M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
BLOCK_SIZE_M = {block_m},
|
||||
BLOCK_SIZE_N = {block_n},
|
||||
BLOCK_SIZE_K = {block_k},
|
||||
GROUP_SIZE_M = {group_m},
|
||||
SPLIT_K = {split_k},
|
||||
num_warps = {num_warps},
|
||||
num_stages = {num_stages}
|
||||
)
|
||||
if warmup:
|
||||
matmul_kernel_{configStr}.warmup(
|
||||
torch.float16, torch.float16, torch.float16,
|
||||
M, N, K,
|
||||
am, ak, bk, bn, cm, cn,
|
||||
BLOCK_SIZE_M = {block_m},
|
||||
BLOCK_SIZE_N = {block_n},
|
||||
BLOCK_SIZE_K = {block_k},
|
||||
GROUP_SIZE_M = {group_m},
|
||||
SPLIT_K = {split_k},
|
||||
num_warps = {num_warps},
|
||||
num_stages = {num_stages},
|
||||
waves_per_eu = {waves_per_eu},
|
||||
grid=(1,)
|
||||
)
|
||||
else:
|
||||
matmul_kernel_{configStr}[grid](
|
||||
a, b, c,
|
||||
M, N, K,
|
||||
am, ak, bk, bn, cm, cn,
|
||||
BLOCK_SIZE_M = {block_m},
|
||||
BLOCK_SIZE_N = {block_n},
|
||||
BLOCK_SIZE_K = {block_k},
|
||||
GROUP_SIZE_M = {group_m},
|
||||
SPLIT_K = {split_k},
|
||||
num_warps = {num_warps},
|
||||
num_stages = {num_stages},
|
||||
waves_per_eu = {waves_per_eu}
|
||||
)
|
||||
return c
|
||||
|
||||
def try_config_{configStr}(M, N, K, dtype):
|
||||
a = torch.randn((M, K), device='cuda', dtype=dtype)
|
||||
b = torch.randn((K, N), device='cuda', dtype=dtype)
|
||||
c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
|
||||
def try_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, dtype):
|
||||
#a = torch.randn((M, K), device='cuda', dtype=dtype)
|
||||
#b = torch.randn((K, N), device='cuda', dtype=dtype)
|
||||
#c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
|
||||
try:
|
||||
matmul_{configStr}(a, b, c)
|
||||
matmul_{configStr}(None, None, None, M, N, K, am, ak, bk, bn, cm, cn, True)
|
||||
except Exception:
|
||||
print(f'invalid config {configStr}')
|
||||
"""
|
||||
return configStr, matmul_def_str
|
||||
|
||||
## Open a file generated_kernelMNK.py and generate
|
||||
## Open {ngpus} files
|
||||
## generated_kernelMNK-0.py, generated_kernelMNK-1.py, ..., generated_kernelMNK-{ngpus-1}.py
|
||||
## and generate
|
||||
## 1. matmul kernels of all configs
|
||||
## 2. wrapper function matmul to invoke all the generated kernels
|
||||
## 3. Another wraper function try_config to invoke matmul function
|
||||
## 4. test_gemm to invoke
|
||||
## 4.1 run try_config in parallel
|
||||
## 4.2 matmul in a loop of 10 iterations
|
||||
def generate_kernel(M, N, K, configs):
|
||||
f_kernel = open(f'generated_kernel{M}{N}{K}.py', 'w')
|
||||
def generate_kernel(M, N, K, configs, ngpus):
|
||||
filenames = []
|
||||
for fi in range(ngpus):
|
||||
filenames.append(f"generated_kernel{M}{N}{K}-{fi}.py")
|
||||
f_kernel = [open(path, 'w') for path in filenames]
|
||||
|
||||
### write imports
|
||||
import_str = """import torch
|
||||
@@ -158,20 +182,24 @@ import argparse
|
||||
import sys
|
||||
import multiprocessing
|
||||
"""
|
||||
f_kernel.write(import_str + "\n")
|
||||
for fi in range(ngpus):
|
||||
f_kernel[fi].write(import_str + "\n")
|
||||
|
||||
### write definitions of matmul_kernel_xxx
|
||||
### and matmul_xxx and try_config
|
||||
with open("matmul_kernel.py") as file:
|
||||
matmul_kernel_code = file.read();
|
||||
idx = 0
|
||||
for config in configs:
|
||||
file_idx = idx % ngpus
|
||||
configStr, matmul_def_str = gen_kernel_and_configStr_from_config(M, N, K, config)
|
||||
## Copy the matmul_kernel with name replaced
|
||||
matmul_kernel_config = matmul_kernel_code.replace("matmul_kernel", f"matmul_kernel_{configStr}")
|
||||
matmul_kernel_config = matmul_kernel_config.replace("import triton.language as tl", "")
|
||||
matmul_kernel_config = matmul_kernel_config.replace("import triton", "")
|
||||
f_kernel.write(matmul_kernel_config + "\n\n")
|
||||
f_kernel.write(matmul_def_str + "\n")
|
||||
f_kernel[file_idx].write(matmul_kernel_config + "\n\n")
|
||||
f_kernel[file_idx].write(matmul_def_str + "\n")
|
||||
idx += 1
|
||||
|
||||
### write test_gemm
|
||||
# pre string
|
||||
@@ -180,25 +208,42 @@ import multiprocessing
|
||||
a = torch.randn((M, K), device='cuda', dtype=dtype)
|
||||
b = torch.randn((K, N), device='cuda', dtype=dtype)
|
||||
c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
|
||||
task_args = (M, N, K, dtype)
|
||||
task_args = (M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1), dtype)
|
||||
|
||||
if num_threads > 1:
|
||||
"""
|
||||
f_kernel.write(test_gemm_pre_str + "\n")
|
||||
for fi in range(ngpus):
|
||||
f_kernel[fi].write(test_gemm_pre_str + "\n")
|
||||
|
||||
# warm up call of all matmul functions in parallel
|
||||
idx = 0
|
||||
for config in configs:
|
||||
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
|
||||
task_str = f" thread_pool.apply_async(try_config_{configStr}, args=task_args)\n"
|
||||
f_kernel.write(task_str)
|
||||
task_str = f" thread_pool.apply_async(try_config_{configStr}, args=task_args)\n"
|
||||
f_kernel[idx % ngpus].write(task_str)
|
||||
idx += 1
|
||||
|
||||
threadpool_str = """
|
||||
thread_pool.close()
|
||||
thread_pool.join()
|
||||
else:"""
|
||||
for fi in range(ngpus):
|
||||
f_kernel[fi].write(threadpool_str)
|
||||
# call all matmul_xxx functions
|
||||
idx = 0
|
||||
for config in configs:
|
||||
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
|
||||
matmul_call_str = f"""
|
||||
for i in range(10):
|
||||
d = matmul_{configStr}(a, b, c)"""
|
||||
f_kernel.write(matmul_call_str + "\n")
|
||||
for i in range(10):
|
||||
d = matmul_{configStr}(a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1))"""
|
||||
f_kernel[idx % ngpus].write(matmul_call_str + "\n")
|
||||
idx += 1
|
||||
# post string
|
||||
f_kernel.write(" return d\n")
|
||||
for fi in range(ngpus):
|
||||
f_kernel[fi].write(" return d\n")
|
||||
|
||||
### def main and call test_gemm
|
||||
def_main_str = """
|
||||
@@ -211,34 +256,69 @@ def main():
|
||||
numThreads = args.n
|
||||
"""
|
||||
test_gemm_call_str = f'test_gemm({M}, {N}, {K}, torch.float16, numThreads)'
|
||||
f_kernel.write(def_main_str)
|
||||
f_kernel.write(test_gemm_call_str + "\n\n")
|
||||
f_kernel.write("""if __name__ == '__main__':
|
||||
sys.exit(main())""")
|
||||
f_kernel.close()
|
||||
for fi in range(ngpus):
|
||||
f_kernel[fi].write(def_main_str)
|
||||
f_kernel[fi].write(test_gemm_call_str + "\n\n")
|
||||
f_kernel[fi].write("""if __name__ == '__main__':
|
||||
sys.exit(main())""")
|
||||
f_kernel[fi].close()
|
||||
|
||||
def extract_kernel_time(M, N, K, config, gpuid):
|
||||
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
|
||||
parse_result_cmd = f'sed -n \'/matmul_kernel_{configStr}/p\' results-{gpuid}.csv | awk -F \',\' \'{{print $NF}}\' | tail -n1'
|
||||
parsed_outputs = run_bash_command(parse_result_cmd)
|
||||
return config, parsed_outputs
|
||||
|
||||
|
||||
def tune_gemm_config(M, N, K, configs):
|
||||
def profile_batch_kernels(M, N, K, gpuid):
|
||||
os.environ['ROCR_VISIBLE_DEVICES'] = str(gpuid)
|
||||
run_bash_command(f"rocprof --stats -o results-{gpuid}.csv python generated_kernel{M}{N}{K}-{gpuid}.py")
|
||||
|
||||
|
||||
def tune_gemm_config(M, N, K, configs, verbose=False, num_threads=16, ngpus = 1):
|
||||
## Generate kernel out of all configs
|
||||
generate_kernel(M, N, K, configs)
|
||||
generate_kernel(M, N, K, configs, ngpus)
|
||||
|
||||
## remove any compiled kernel in the cache
|
||||
run_bash_command("rm -rf ~/.triton/cache")
|
||||
|
||||
## precompile the kernels in parallel
|
||||
## TODO: parameterize numThreads at this level
|
||||
run_bash_command(f"python generated_kernel{M}{N}{K}.py -n 16")
|
||||
start_time = datetime.now()
|
||||
for fi in range(ngpus):
|
||||
run_bash_command(f"python generated_kernel{M}{N}{K}-{fi}.py -n 32")
|
||||
compile_end = datetime.now()
|
||||
compile_time = compile_end - start_time
|
||||
if verbose:
|
||||
print(f"compile time: {compile_time}")
|
||||
|
||||
## profile generated kernels
|
||||
run_bash_command(f"rocprof --stats python generated_kernel{M}{N}{K}.py")
|
||||
running = [multiprocessing.Process(target=profile_batch_kernels, args=(M,N,K,fi)) for fi in range(ngpus)]
|
||||
for p in running:
|
||||
p.start()
|
||||
for p in running:
|
||||
p.join()
|
||||
|
||||
profile_end = datetime.now()
|
||||
profile_time = profile_end - compile_end
|
||||
if verbose:
|
||||
print(f"profile time: {profile_time}")
|
||||
|
||||
## post process results.csv to get the best config and minTime
|
||||
## TODO: process the file in parallel
|
||||
minTime = 1024 * 1024 * 1024
|
||||
thread_pool = multiprocessing.Pool(processes=num_threads)
|
||||
tasks = []
|
||||
idx = 0
|
||||
for config in configs:
|
||||
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
|
||||
parse_result_cmd = f'sed -n \'/matmul_kernel_{configStr}/p\' results.csv | awk -F \',\' \'{{print $NF}}\' | tail -n1'
|
||||
parsed_outputs = run_bash_command(parse_result_cmd)
|
||||
file_idx = idx % ngpus
|
||||
tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, file_idx))]
|
||||
idx += 1
|
||||
thread_pool.close()
|
||||
thread_pool.join()
|
||||
|
||||
for task in tasks:
|
||||
config, parsed_outputs = task.get()
|
||||
if parsed_outputs:
|
||||
min_us = int(parsed_outputs[0]) / 1000
|
||||
if min_us < minTime:
|
||||
@@ -247,10 +327,14 @@ def tune_gemm_config(M, N, K, configs):
|
||||
else:
|
||||
min_us = -1
|
||||
print(f"invalid config: SIZE {M} {N} {K}: {config}")
|
||||
return minTime, bestConfig
|
||||
post_end = datetime.now()
|
||||
post_time = post_end - profile_end
|
||||
if verbose:
|
||||
print(f"post procesing time: {post_time}")
|
||||
return minTime, bestConfig, compile_time, profile_time, post_time
|
||||
|
||||
|
||||
def matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages):
|
||||
def matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu):
|
||||
# Check constraints.
|
||||
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
|
||||
assert a.is_contiguous(), "Matrix A must be contiguous"
|
||||
@@ -276,19 +360,20 @@ def matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_
|
||||
SPLIT_K = split_k,
|
||||
num_warps = num_warps,
|
||||
num_stages = num_stages,
|
||||
waves_per_eu = waves_per_eu,
|
||||
)
|
||||
return c
|
||||
|
||||
|
||||
def test_correctness(M, N, K, config, verbose, datatype = torch.float16):
|
||||
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages = read_config(config)
|
||||
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu = read_config(config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
a = torch.randn((M, K), device='cuda', dtype=datatype)
|
||||
b = torch.randn((K, N), device='cuda', dtype=datatype)
|
||||
# Allocates output.
|
||||
c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
|
||||
triton_output = matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages)
|
||||
triton_output = matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu)
|
||||
torch_output = torch.matmul(a, b)
|
||||
#print(f"triton_output={triton_output}")
|
||||
#print(f"torch_output={torch_output}")
|
||||
@@ -322,11 +407,13 @@ def parse_args():
|
||||
parser.add_argument("-m", type=int, default=0)
|
||||
parser.add_argument("-n", type=int, default=0)
|
||||
parser.add_argument("-k", type=int, default=0)
|
||||
parser.add_argument("--ngpus", type=int, default=1, help='number of GPUs used in the profiling step')
|
||||
parser.add_argument("--gemm_size_file", type=str, default="", help='yaml file to indicate matrix size')
|
||||
parser.add_argument("--tuning_results_file", type=str, default=get_default_tuning_result_filename(), help='yaml file to store tuning results')
|
||||
parser.add_argument("--keep", action='store_true', default=False, help='keep generated files')
|
||||
parser.add_argument("--compare", action='store_true', default=False, help="Whether check result correctness")
|
||||
parser.add_argument("--compare_wo_tuning", action='store_true', default=False, help="Whether check result correctness")
|
||||
parser.add_argument("--time_breakdown", action='store_true', default=False, help="Show detailed time breakdown of each step during the tuning")
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
@@ -337,6 +424,7 @@ def main():
|
||||
matrix_size_file = args.gemm_size_file
|
||||
tuning_output_file = args.tuning_results_file
|
||||
keepTmp = args.keep
|
||||
ngpus = args.ngpus
|
||||
|
||||
mnks = []
|
||||
## TODO: make it more robust to get user input
|
||||
@@ -369,9 +457,11 @@ def main():
|
||||
configs_full = get_full_tuning_space()
|
||||
|
||||
start_time = datetime.now()
|
||||
print(f"Tuning starts at: {start_time}")
|
||||
|
||||
f_results = open(tuning_output_file, 'w')
|
||||
for (M, N, K) in mnks:
|
||||
start_local_time = datetime.now()
|
||||
## Obtain a pruned tuning space according to gemm size
|
||||
pruned_configs = prune_configs(M, N, K, configs_full)
|
||||
|
||||
@@ -379,7 +469,7 @@ def main():
|
||||
print(f"{size_str} nConfigs: {len(pruned_configs)}", end=" ", flush=True)
|
||||
|
||||
## The main tuning funtion for one gemm size
|
||||
minTime, bestConfig = tune_gemm_config(M, N, K, pruned_configs)
|
||||
minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(M, N, K, pruned_configs, ngpus = ngpus, verbose=args.time_breakdown)
|
||||
|
||||
## post processing the numbers
|
||||
perf_tflops = lambda us: 2 * M * N * K * 1e-12 / (us * 1e-6)
|
||||
@@ -401,9 +491,10 @@ def main():
|
||||
|
||||
## remove generated files if asked to
|
||||
if not keepTmp:
|
||||
os.remove(f"generated_kernel{M}{N}{K}.py")
|
||||
for f in glob.glob("results.*"):
|
||||
os.remove(f)
|
||||
for fi in range(ngpus):
|
||||
os.remove(f"generated_kernel{M}{N}{K}-{fi}.py")
|
||||
for f in glob.glob(f"results-{fi}.*"):
|
||||
os.remove(f)
|
||||
|
||||
## Check correctness if asked to
|
||||
if args.compare:
|
||||
@@ -412,11 +503,15 @@ def main():
|
||||
else:
|
||||
print("")
|
||||
|
||||
end_local_time = datetime.now()
|
||||
print(f">>> Elapsed time: {end_local_time - start_local_time} = {compile_time} (compile) + {profile_time} (profile) + {post_time} (post processing)")
|
||||
|
||||
f_results.close()
|
||||
|
||||
end_time = datetime.now()
|
||||
tuning_time = end_time - start_time
|
||||
print(f"Tuning time (h:m:s): {tuning_time}")
|
||||
print(f"Tuning ends at: {end_time}")
|
||||
print(f"Total tuning time (h:m:s): {tuning_time}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user