[GEMM] [Tuning] Add waves_per_eu to gemm tuning (#362)

* Add waves_per_eu to the tuning space (see the sketch after this list)

* Do not allocate tensor on device during kernel compilation step

* Add breakdown elapsed time

* Parallelize the post-processing step

* Parallelize the profile step with --ngpus

* Better timing info printout
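For context, waves_per_eu is an AMD GPU occupancy hint: the requested number of wavefronts per execution unit, with 0 conventionally leaving the choice to the compiler. Below is a minimal sketch (not part of this diff) of how the knob is passed at launch time, assuming the ROCm Triton build this tuner targets, which accepts waves_per_eu as a launch keyword; the copy kernel is only a stand-in:

import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)

x = torch.randn(4096, device='cuda')
y = torch.empty_like(x)
# waves_per_eu rides along with num_warps/num_stages as a launch kwarg
copy_kernel[(triton.cdiv(4096, 256),)](x, y, 4096, BLOCK=256,
                                       num_warps=4, waves_per_eu=2)

In the tuner it becomes one more key in every config dict, so the five candidate values multiply the search space by five.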
Lixun Zhang committed 2023-10-16 13:50:03 -05:00 (committed by GitHub)
parent 821e75a2b0
commit 1de859df32

@@ -12,6 +12,7 @@ import triton.language as tl
from matmul_kernel import matmul_kernel
from datetime import datetime
import multiprocessing
def get_full_tuning_space():
@@ -26,6 +27,7 @@ def get_full_tuning_space():
# But keep this explicit so that we do not forget we may need to set it to
# other values in the future
num_stage_range = [1, 0]
waves_per_eu_range = [0, 1, 2, 3, 4]
for block_m in block_mn_range:
for block_n in block_mn_range:
@@ -34,7 +36,8 @@ def get_full_tuning_space():
for group_m in group_m_range:
for split_k in split_k_range:
for num_stages in num_stage_range:
configs.append({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k, 'num_warps': num_warps, 'num_stages': num_stages})
for waves_per_eu in waves_per_eu_range:
configs.append({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k, 'num_warps': num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu})
return configs
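As an aside (not part of this diff), the nested loops above are equivalent to a single itertools.product; a self-contained sketch with hypothetical stand-ins for the range lists defined earlier in the function (iteration order differs, the resulting set of configs is identical):

import itertools

block_mn_range = [32, 64, 128]        # hypothetical stand-in
block_k_range = [32, 64]              # hypothetical stand-in
group_m_range = [1, 4, 8]             # hypothetical stand-in
split_k_range = [1, 2]                # hypothetical stand-in
num_warps_range = [1, 2, 4, 8]        # hypothetical stand-in
num_stage_range = [1, 0]              # as in this hunk
waves_per_eu_range = [0, 1, 2, 3, 4]  # as in this hunk

keys = ('BLOCK_SIZE_M', 'BLOCK_SIZE_N', 'BLOCK_SIZE_K', 'GROUP_SIZE_M',
        'SPLIT_K', 'num_warps', 'num_stages', 'waves_per_eu')
configs = [dict(zip(keys, vals)) for vals in itertools.product(
    block_mn_range, block_mn_range, block_k_range, group_m_range,
    split_k_range, num_warps_range, num_stage_range, waves_per_eu_range)]
print(len(configs))  # 4320 with the stand-in lists above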
@@ -100,55 +103,76 @@ def read_config(config):
split_k = config.get('SPLIT_K')
num_warps = config.get('num_warps')
num_stages = config.get('num_stages')
return block_m, block_n, block_k, group_m, split_k, num_warps, num_stages
waves_per_eu = config.get('waves_per_eu')
return block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu
def gen_kernel_and_configStr_from_config(M, N, K, config):
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages = read_config(config)
configStr = f"M{M}_N{N}_K{K}_BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}"
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu = read_config(config)
configStr = f"M{M}_N{N}_K{K}_BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}"
matmul_def_str = f"""
def matmul_{configStr}(a, b, c):
M, K = a.shape
K, N = b.shape
def matmul_{configStr}(a, b, c, M, N, K, am, ak, bk, bn, cm, cn, warmup=False):
#M, K = a.shape
#K, N = b.shape
grid = triton.cdiv(M, {block_m}) * triton.cdiv(N, {block_n}), {split_k}
print(f'config: matmul_kernel_{configStr}')
matmul_kernel_{configStr}[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_SIZE_M = {block_m},
BLOCK_SIZE_N = {block_n},
BLOCK_SIZE_K = {block_k},
GROUP_SIZE_M = {group_m},
SPLIT_K = {split_k},
num_warps = {num_warps},
num_stages = {num_stages}
)
if warmup:
matmul_kernel_{configStr}.warmup(
torch.float16, torch.float16, torch.float16,
M, N, K,
am, ak, bk, bn, cm, cn,
BLOCK_SIZE_M = {block_m},
BLOCK_SIZE_N = {block_n},
BLOCK_SIZE_K = {block_k},
GROUP_SIZE_M = {group_m},
SPLIT_K = {split_k},
num_warps = {num_warps},
num_stages = {num_stages},
waves_per_eu = {waves_per_eu},
grid=(1,)
)
else:
matmul_kernel_{configStr}[grid](
a, b, c,
M, N, K,
am, ak, bk, bn, cm, cn,
BLOCK_SIZE_M = {block_m},
BLOCK_SIZE_N = {block_n},
BLOCK_SIZE_K = {block_k},
GROUP_SIZE_M = {group_m},
SPLIT_K = {split_k},
num_warps = {num_warps},
num_stages = {num_stages},
waves_per_eu = {waves_per_eu}
)
return c
def try_config_{configStr}(M, N, K, dtype):
a = torch.randn((M, K), device='cuda', dtype=dtype)
b = torch.randn((K, N), device='cuda', dtype=dtype)
c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
def try_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, dtype):
#a = torch.randn((M, K), device='cuda', dtype=dtype)
#b = torch.randn((K, N), device='cuda', dtype=dtype)
#c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
try:
matmul_{configStr}(a, b, c)
matmul_{configStr}(None, None, None, M, N, K, am, ak, bk, bn, cm, cn, True)
except Exception:
print(f'invalid config {configStr}')
"""
return configStr, matmul_def_str
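The warmup branch above is what "do not allocate tensor on device during kernel compilation" refers to: JITFunction.warmup compiles a specialization from dtypes, scalar shape/stride arguments, and a dummy grid, without touching GPU memory. A minimal sketch of the same idea on a throwaway kernel (not part of this diff), assuming the warmup signature used above:

import torch
import triton
import triton.language as tl

@triton.jit
def scale_kernel(x_ptr, y_ptr, n, alpha, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(y_ptr + offs, alpha * tl.load(x_ptr + offs, mask=mask), mask=mask)

# Compile-only: dtypes stand in for the tensor arguments, scalars pass
# through unchanged, and grid=(1,) satisfies the API without launching.
scale_kernel.warmup(torch.float16, torch.float16, 4096, 2.0,
                    BLOCK=256, grid=(1,))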
## Open a file generated_kernelMNK.py and generate
## Open {ngpus} files
## generated_kernelMNK-0.py, generated_kernelMNK-1.py, ..., generated_kernelMNK-{ngpus-1}.py
## and generate
## 1. matmul kernels of all configs
## 2. wrapper function matmul to invoke all the generated kernels
## 3. Another wrapper function try_config to invoke the matmul function
## 4. test_gemm to invoke
## 4.1 run try_config in parallel
## 4.2 matmul in a loop of 10 iterations
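The split across the {ngpus} files described above is a plain round-robin: config idx lands in file idx % ngpus, so each GPU later profiles its own file. A toy, self-contained sketch of the distribution (file names and the stub source are hypothetical, not part of this diff):

ngpus = 2
configs = ['cfgA', 'cfgB', 'cfgC', 'cfgD', 'cfgE']  # stand-ins for config dicts
files = [open(f"toy_generated_kernel-{i}.py", "w") for i in range(ngpus)]
for idx, cfg in enumerate(configs):
    # the real generator writes one kernel clone plus wrappers per config
    files[idx % ngpus].write(f"# specialized source for {cfg}\n")
for f in files:
    f.close()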
def generate_kernel(M, N, K, configs):
f_kernel = open(f'generated_kernel{M}{N}{K}.py', 'w')
def generate_kernel(M, N, K, configs, ngpus):
filenames = []
for fi in range(ngpus):
filenames.append(f"generated_kernel{M}{N}{K}-{fi}.py")
f_kernel = [open(path, 'w') for path in filenames]
### write imports
import_str = """import torch
@@ -158,20 +182,24 @@ import argparse
import sys
import multiprocessing
"""
f_kernel.write(import_str + "\n")
for fi in range(ngpus):
f_kernel[fi].write(import_str + "\n")
### write definitions of matmul_kernel_xxx
### and matmul_xxx and try_config
with open("matmul_kernel.py") as file:
matmul_kernel_code = file.read()
idx = 0
for config in configs:
file_idx = idx % ngpus
configStr, matmul_def_str = gen_kernel_and_configStr_from_config(M, N, K, config)
## Copy the matmul_kernel with name replaced
matmul_kernel_config = matmul_kernel_code.replace("matmul_kernel", f"matmul_kernel_{configStr}")
matmul_kernel_config = matmul_kernel_config.replace("import triton.language as tl", "")
matmul_kernel_config = matmul_kernel_config.replace("import triton", "")
f_kernel.write(matmul_kernel_config + "\n\n")
f_kernel.write(matmul_def_str + "\n")
f_kernel[file_idx].write(matmul_kernel_config + "\n\n")
f_kernel[file_idx].write(matmul_def_str + "\n")
idx += 1
### write test_gemm
# pre string
@@ -180,25 +208,42 @@ import multiprocessing
a = torch.randn((M, K), device='cuda', dtype=dtype)
b = torch.randn((K, N), device='cuda', dtype=dtype)
c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
task_args = (M, N, K, dtype)
task_args = (M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1), dtype)
if num_threads > 1:
"""
f_kernel.write(test_gemm_pre_str + "\n")
for fi in range(ngpus):
f_kernel[fi].write(test_gemm_pre_str + "\n")
# warm up call of all matmul functions in parallel
idx = 0
for config in configs:
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
task_str = f" thread_pool.apply_async(try_config_{configStr}, args=task_args)\n"
f_kernel.write(task_str)
task_str = f" thread_pool.apply_async(try_config_{configStr}, args=task_args)\n"
f_kernel[idx % ngpus].write(task_str)
idx += 1
threadpool_str = """
thread_pool.close()
thread_pool.join()
else:"""
for fi in range(ngpus):
f_kernel[fi].write(threadpool_str)
# call all matmul_xxx functions
idx = 0
for config in configs:
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
matmul_call_str = f"""
for i in range(10):
d = matmul_{configStr}(a, b, c)"""
f_kernel.write(matmul_call_str + "\n")
for i in range(10):
d = matmul_{configStr}(a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1))"""
f_kernel[idx % ngpus].write(matmul_call_str + "\n")
idx += 1
# post string
f_kernel.write(" return d\n")
for fi in range(ngpus):
f_kernel[fi].write(" return d\n")
### def main and call test_gemm
def_main_str = """
@@ -211,34 +256,69 @@ def main():
numThreads = args.n
"""
test_gemm_call_str = f'test_gemm({M}, {N}, {K}, torch.float16, numThreads)'
f_kernel.write(def_main_str)
f_kernel.write(test_gemm_call_str + "\n\n")
f_kernel.write("""if __name__ == '__main__':
sys.exit(main())""")
f_kernel.close()
for fi in range(ngpus):
f_kernel[fi].write(def_main_str)
f_kernel[fi].write(test_gemm_call_str + "\n\n")
f_kernel[fi].write("""if __name__ == '__main__':
sys.exit(main())""")
f_kernel[fi].close()
def extract_kernel_time(M, N, K, config, gpuid):
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
parse_result_cmd = f'sed -n \'/matmul_kernel_{configStr}/p\' results-{gpuid}.csv | awk -F \',\' \'{{print $NF}}\' | tail -n1'
parsed_outputs = run_bash_command(parse_result_cmd)
return config, parsed_outputs
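The shell pipeline above selects every line of results-{gpuid}.csv that mentions the kernel name, keeps the last comma-separated field of each, and takes the final match. A pure-Python sketch of the same parsing (not part of this diff), assuming the duration sits in the last column in ns, as the later division by 1000 into microseconds suggests:

def extract_kernel_time_py(configStr, gpuid):
    last = None
    with open(f"results-{gpuid}.csv") as f:
        for line in f:
            if f"matmul_kernel_{configStr}" in line:
                last = line  # keep only the final matching row, like tail -n1
    return int(last.rsplit(",", 1)[-1]) if last else None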
def tune_gemm_config(M, N, K, configs):
def profile_batch_kernels(M, N, K, gpuid):
os.environ['ROCR_VISIBLE_DEVICES'] = str(gpuid)
run_bash_command(f"rocprof --stats -o results-{gpuid}.csv python generated_kernel{M}{N}{K}-{gpuid}.py")
def tune_gemm_config(M, N, K, configs, verbose=False, num_threads=16, ngpus = 1):
## Generate kernel out of all configs
generate_kernel(M, N, K, configs)
generate_kernel(M, N, K, configs, ngpus)
## remove any compiled kernel in the cache
run_bash_command("rm -rf ~/.triton/cache")
## precompile the kernels in parallel
## TODO: parameterize numThreads at this level
run_bash_command(f"python generated_kernel{M}{N}{K}.py -n 16")
start_time = datetime.now()
for fi in range(ngpus):
run_bash_command(f"python generated_kernel{M}{N}{K}-{fi}.py -n 32")
compile_end = datetime.now()
compile_time = compile_end - start_time
if verbose:
print(f"compile time: {compile_time}")
## profile generated kernels
run_bash_command(f"rocprof --stats python generated_kernel{M}{N}{K}.py")
running = [multiprocessing.Process(target=profile_batch_kernels, args=(M,N,K,fi)) for fi in range(ngpus)]
for p in running:
p.start()
for p in running:
p.join()
profile_end = datetime.now()
profile_time = profile_end - compile_end
if verbose:
print(f"profile time: {profile_time}")
## post process results.csv to get the best config and minTime
## TODO: process the file in parallel
minTime = 1024 * 1024 * 1024
thread_pool = multiprocessing.Pool(processes=num_threads)
tasks = []
idx = 0
for config in configs:
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
parse_result_cmd = f'sed -n \'/matmul_kernel_{configStr}/p\' results.csv | awk -F \',\' \'{{print $NF}}\' | tail -n1'
parsed_outputs = run_bash_command(parse_result_cmd)
file_idx = idx % ngpus
tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, file_idx))]
idx += 1
thread_pool.close()
thread_pool.join()
for task in tasks:
config, parsed_outputs = task.get()
if parsed_outputs:
min_us = int(parsed_outputs[0]) / 1000
if min_us < minTime:
@@ -247,10 +327,14 @@ def tune_gemm_config(M, N, K, configs):
else:
min_us = -1
print(f"invalid config: SIZE {M} {N} {K}: {config}")
return minTime, bestConfig
post_end = datetime.now()
post_time = post_end - profile_end
if verbose:
print(f"post procesing time: {post_time}")
return minTime, bestConfig, compile_time, profile_time, post_time
def matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages):
def matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.is_contiguous(), "Matrix A must be contiguous"
@@ -276,19 +360,20 @@ def matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_
SPLIT_K = split_k,
num_warps = num_warps,
num_stages = num_stages,
waves_per_eu = waves_per_eu,
)
return c
def test_correctness(M, N, K, config, verbose, datatype = torch.float16):
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages = read_config(config)
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu = read_config(config)
torch.manual_seed(0)
a = torch.randn((M, K), device='cuda', dtype=datatype)
b = torch.randn((K, N), device='cuda', dtype=datatype)
# Allocates output.
c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
triton_output = matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages)
triton_output = matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu)
torch_output = torch.matmul(a, b)
#print(f"triton_output={triton_output}")
#print(f"torch_output={torch_output}")
@@ -322,11 +407,13 @@ def parse_args():
parser.add_argument("-m", type=int, default=0)
parser.add_argument("-n", type=int, default=0)
parser.add_argument("-k", type=int, default=0)
parser.add_argument("--ngpus", type=int, default=1, help='number of GPUs used in the profiling step')
parser.add_argument("--gemm_size_file", type=str, default="", help='yaml file to indicate matrix size')
parser.add_argument("--tuning_results_file", type=str, default=get_default_tuning_result_filename(), help='yaml file to store tuning results')
parser.add_argument("--keep", action='store_true', default=False, help='keep generated files')
parser.add_argument("--compare", action='store_true', default=False, help="Whether check result correctness")
parser.add_argument("--compare_wo_tuning", action='store_true', default=False, help="Whether check result correctness")
parser.add_argument("--time_breakdown", action='store_true', default=False, help="Show detailed time breakdown of each step during the tuning")
args = parser.parse_args()
return args
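With the two new flags, a four-GPU tuning run that prints the per-step breakdown might be invoked as python tune_gemm.py --gemm_size_file input.yaml --ngpus 4 --time_breakdown (script and YAML file names assumed).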
@@ -337,6 +424,7 @@ def main():
matrix_size_file = args.gemm_size_file
tuning_output_file = args.tuning_results_file
keepTmp = args.keep
ngpus = args.ngpus
mnks = []
## TODO: make the parsing of user input more robust
@@ -369,9 +457,11 @@ def main():
configs_full = get_full_tuning_space()
start_time = datetime.now()
print(f"Tuning starts at: {start_time}")
f_results = open(tuning_output_file, 'w')
for (M, N, K) in mnks:
start_local_time = datetime.now()
## Obtain a pruned tuning space according to gemm size
pruned_configs = prune_configs(M, N, K, configs_full)
@@ -379,7 +469,7 @@ def main():
print(f"{size_str} nConfigs: {len(pruned_configs)}", end=" ", flush=True)
## The main tuning function for one gemm size
minTime, bestConfig = tune_gemm_config(M, N, K, pruned_configs)
minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(M, N, K, pruned_configs, ngpus = ngpus, verbose=args.time_breakdown)
## post processing the numbers
perf_tflops = lambda us: 2 * M * N * K * 1e-12 / (us * 1e-6)
@@ -401,9 +491,10 @@ def main():
## remove generated files if asked to
if not keepTmp:
os.remove(f"generated_kernel{M}{N}{K}.py")
for f in glob.glob("results.*"):
os.remove(f)
for fi in range(ngpus):
os.remove(f"generated_kernel{M}{N}{K}-{fi}.py")
for f in glob.glob(f"results-{fi}.*"):
os.remove(f)
## Check correctness if asked to
if args.compare:
@@ -412,11 +503,15 @@ def main():
else:
print("")
end_local_time = datetime.now()
print(f">>> Elapsed time: {end_local_time - start_local_time} = {compile_time} (compile) + {profile_time} (profile) + {post_time} (post processing)")
f_results.close()
end_time = datetime.now()
tuning_time = end_time - start_time
print(f"Tuning time (h:m:s): {tuning_time}")
print(f"Tuning ends at: {end_time}")
print(f"Total tuning time (h:m:s): {tuning_time}")
if __name__ == '__main__':