[GEMM] [Tuning] Make tuning script more verbose (#420)

This PR adds:
- a verbose tuning mode (`--verbose`) that prints the standard output of the compilation and tuning calls
- collection of information about configs that failed to compile
- a word ("Correct" / "Incorrect") in the correctness check output
- "-" as a separator between the M, N, K dimensions in generated script names
- a `--gpu_ids` option to select the particular GPUs to use
This commit is contained in:
Alexander Efimov
2023-12-11 05:04:00 +01:00
committed by GitHub
parent e19b5fd6bc
commit 2be6ec771e

View File

@@ -87,9 +87,12 @@ def need_split_k(SIZE_M, SIZE_N, SIZE_K):
return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
def run_bash_command(commandstring):
    """Run *commandstring* through /bin/bash and return its captured stdout.

    Args:
        commandstring: Shell command line; it is executed with ``shell=True``.

    Returns:
        The command's standard output split into lines, each line as ``bytes``.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status (the call uses ``check=True``).
    """
    proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash', stdout=subprocess.PIPE)
    return proc.stdout.splitlines()
def run_bash_command(commandstring, capture=True):
    """Run *commandstring* through /bin/bash, optionally capturing its stdout.

    Args:
        commandstring: Shell command line; it is executed with ``shell=True``.
        capture: When true, stdout is piped and returned to the caller.
            When false, stdout is not redirected, so the child writes to the
            same stdout as this process.

    Returns:
        A list of stdout lines (each as ``bytes``) when *capture* is true,
        otherwise ``None``.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status (both branches use ``check=True``).
    """
    if capture:
        proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash', stdout=subprocess.PIPE)
        return proc.stdout.splitlines()
    # Nothing is captured here, so the CompletedProcess result is not kept.
    subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash')
    return None
def read_config(config):
@@ -113,7 +116,7 @@ def matmul_{configStr}(a, b, c, M, N, K, am, ak, bk, bn, cm, cn, warmup=False):
#M, K = a.shape
#K, N = b.shape
grid = triton.cdiv(M, {block_m}) * triton.cdiv(N, {block_n}), {split_k}
print(f'config: matmul_kernel_{configStr}')
print(f'config: matmul_kernel_{configStr}', flush=True)
if warmup:
matmul_kernel_{configStr}.warmup(
torch.float16, torch.float16, torch.float16,
@@ -129,6 +132,7 @@ def matmul_{configStr}(a, b, c, M, N, K, am, ak, bk, bn, cm, cn, warmup=False):
waves_per_eu = {waves_per_eu},
grid=(1,)
)
return None
else:
matmul_kernel_{configStr}[grid](
a, b, c,
@@ -143,7 +147,7 @@ def matmul_{configStr}(a, b, c, M, N, K, am, ak, bk, bn, cm, cn, warmup=False):
num_stages = {num_stages},
waves_per_eu = {waves_per_eu}
)
return c
return c
def try_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, dtype):
#a = torch.randn((M, K), device='cuda', dtype=dtype)
@@ -151,13 +155,20 @@ def try_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, dtype):
#c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
try:
matmul_{configStr}(None, None, None, M, N, K, am, ak, bk, bn, cm, cn, True)
except Exception:
print(f'invalid config {configStr}')
return True
except Exception as e:
print(f'invalid config(compilation): {configStr}: ', e, flush=True)
return False
"""
return configStr, matmul_def_str
## Open {ngpus} files
## generated_kernelMNK-0.py, generated_kernelMNK-1.py, ..., generated_kernelMNK-{ngpus-1}.py
def generated_kernel_name(M, N, K, gpu_id):
    """Return the file name of the generated kernel script for one size and GPU.

    Args:
        M, N, K: GEMM dimensions embedded in the name.
        gpu_id: Id of the GPU the script is generated for.

    Returns:
        A name of the form ``generated_kernel<M>-<N>-<K>-<gpu_id>.py``.
    """
    # The "-" separators keep names unambiguous: without them, sizes such as
    # (1, 11, 1) and (11, 1, 1) would both render as "1111".
    return f"generated_kernel{M}-{N}-{K}-{gpu_id}.py"
## Open {len(gpus)} files
## generated_kernelM-N-K-{gpus[0]}.py, generated_kernelM-N-K-{gpus[1]}.py, ..., generated_kernelM-N-K-{gpus[-1]}.py
## and generate
## 1. matmul kernels of all configs
## 2. wrapper function matmul to invoke all the generated kernels
@@ -165,10 +176,11 @@ def try_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, dtype):
## 4. test_gemm to invoke
## 4.1 run try_config in parallel
## 4.2 matmul in a loop of 10 iterations
def generate_kernel(M, N, K, configs, ngpus):
def generate_kernel(M, N, K, configs, gpus):
filenames = []
for fi in range(ngpus):
filenames.append(f"generated_kernel{M}{N}{K}-{fi}.py")
ngpus = len(gpus)
for gpu_id in gpus:
filenames.append(generated_kernel_name(M, N, K, gpu_id))
f_kernel = [open(path, 'w') for path in filenames]
### write imports
@@ -185,7 +197,7 @@ import multiprocessing
### write definitions of matmul_kernel_xxx
### and matmul_xxx and try_config
with open("matmul_kernel.py") as file:
matmul_kernel_code = file.read();
matmul_kernel_code = file.read()
idx = 0
for config in configs:
file_idx = idx % ngpus
@@ -211,6 +223,8 @@ import multiprocessing
c.stride(0), c.stride(1), dtype)
if num_threads > 1:
results = []
config_names = []
"""
for fi in range(ngpus):
f_kernel[fi].write(test_gemm_pre_str + "\n")
@@ -219,23 +233,40 @@ import multiprocessing
idx = 0
for config in configs:
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
task_str = f" thread_pool.apply_async(try_config_{configStr}, args=task_args)\n"
task_str = f" results += [thread_pool.apply_async(try_config_{configStr}, args=task_args)]\n" + \
f" config_names += ['{configStr}']\n"
f_kernel[idx % ngpus].write(task_str)
idx += 1
threadpool_str = """
for fi in range(ngpus):
threadpool_str = """
failed_configs = []
for i in range(len(results)):
results[i].wait()
res = results[i].get()
if not res:
failed_configs += [config_names[i]]
thread_pool.close()
thread_pool.join()
else:"""
for fi in range(ngpus):
with open("{filename}.failed_configs", "w") as f:
for cfg in failed_configs:
f.write(cfg + "\\n")
else:
try:
with open("{filename}.failed_configs", "r") as f:
failed_configs = [cfg.strip() for cfg in f.readlines()]
except Exception:
failed_configs = []
""".format(filename = filenames[fi])
f_kernel[fi].write(threadpool_str)
# call all matmul_xxx functions
idx = 0
for config in configs:
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
matmul_call_str = f"""
for i in range(10):
d = matmul_{configStr}(a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1))"""
if '{configStr}' not in failed_configs:
for i in range(10):
d = matmul_{configStr}(a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1))"""
f_kernel[idx % ngpus].write(matmul_call_str + "\n")
idx += 1
# post string
@@ -267,30 +298,29 @@ def extract_kernel_time(M, N, K, config, gpuid):
return config, parsed_outputs
def profile_batch_kernels(M, N, K, gpuid):
def profile_batch_kernels(M, N, K, gpuid, verbose):
os.environ['ROCR_VISIBLE_DEVICES'] = str(gpuid)
run_bash_command(f"rocprof --stats -o results-{gpuid}.csv python generated_kernel{M}{N}{K}-{gpuid}.py")
run_bash_command(f"rocprof --stats -o results-{gpuid}.csv python {generated_kernel_name(M, N, K, gpuid)}", capture=(verbose < 2))
def tune_gemm_config(M, N, K, configs, verbose=False, num_threads=16, ngpus = 1):
def tune_gemm_config(M, N, K, configs, verbose=0, num_threads=16, gpus = [0]):
## Generate kernel out of all configs
generate_kernel(M, N, K, configs, ngpus)
generate_kernel(M, N, K, configs, gpus)
## remove any compiled kernel in the cache
run_bash_command("rm -rf ~/.triton/cache")
## precompile the kernels in parallel
## TODO: parameterize numThreads at this level
start_time = datetime.now()
for fi in range(ngpus):
run_bash_command(f"python generated_kernel{M}{N}{K}-{fi}.py -n 32")
for gpu_id in gpus:
run_bash_command(f"python {generated_kernel_name(M, N, K, gpu_id)} -n {num_threads}", capture=(verbose < 2))
compile_end = datetime.now()
compile_time = compile_end - start_time
if verbose:
print(f"compile time: {compile_time}")
print(f"compile time: {compile_time}", flush=True)
## profile generated kernels
running = [multiprocessing.Process(target=profile_batch_kernels, args=(M,N,K,fi)) for fi in range(ngpus)]
running = [multiprocessing.Process(target=profile_batch_kernels, args=(M,N,K,gpu_id,verbose)) for gpu_id in gpus]
for p in running:
p.start()
for p in running:
@@ -299,7 +329,7 @@ def tune_gemm_config(M, N, K, configs, verbose=False, num_threads=16, ngpus = 1)
profile_end = datetime.now()
profile_time = profile_end - compile_end
if verbose:
print(f"profile time: {profile_time}")
print(f"profile time: {profile_time}", flush=True)
## post process results.csv to get the best config and minTime
## TODO: process the file in parallel
@@ -308,8 +338,9 @@ def tune_gemm_config(M, N, K, configs, verbose=False, num_threads=16, ngpus = 1)
tasks = []
idx = 0
for config in configs:
file_idx = idx % ngpus
tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, file_idx))]
file_idx = idx % len(gpus)
gpu_id = gpus[file_idx]
tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, gpu_id))]
idx += 1
thread_pool.close()
thread_pool.join()
@@ -323,11 +354,11 @@ def tune_gemm_config(M, N, K, configs, verbose=False, num_threads=16, ngpus = 1)
bestConfig = config
else:
min_us = -1
print(f"invalid config: SIZE {M} {N} {K}: {config}")
print(f"invalid config(post processing): SIZE {M} {N} {K}: {config}", flush=True)
post_end = datetime.now()
post_time = post_end - profile_end
if verbose:
print(f"post procesing time: {post_time}")
print(f"post procesing time: {post_time}", flush=True)
return minTime, bestConfig, compile_time, profile_time, post_time
@@ -379,9 +410,9 @@ def test_correctness(M, N, K, config, verbose, datatype = torch.float16):
if verbose:
size_str = f'SIZE M: {M}, N: {N}, K: {K} '
if torch.allclose(triton_output, torch_output, atol=1e-1, rtol=rtol):
print(f'{size_str}')
print(f'{size_str} Correct')
else:
print(f'{size_str}')
print(f'{size_str} Incorrect')
def get_default_tuning_result_filename():
@@ -404,13 +435,16 @@ def parse_args():
parser.add_argument("-m", type=int, default=0)
parser.add_argument("-n", type=int, default=0)
parser.add_argument("-k", type=int, default=0)
parser.add_argument("--ngpus", type=int, default=1, help='number of GPUs used in the profiling step')
parser.add_argument("--ngpus", type=int, default=0, help='number of GPUs used in the profiling step')
parser.add_argument("--gpu_ids", type=lambda s: [int(id) for id in s.split(',')], default=[], help='list of gpu ids to use for tuning')
parser.add_argument("--gemm_size_file", type=str, default="", help='yaml file to indicate matrix size')
parser.add_argument("--tuning_results_file", type=str, default=get_default_tuning_result_filename(), help='yaml file to store tuning results')
parser.add_argument("--keep", action='store_true', default=False, help='keep generated files')
parser.add_argument("--compare", action='store_true', default=False, help="Whether check result correctness")
parser.add_argument("--compare_wo_tuning", action='store_true', default=False, help="Whether check result correctness")
parser.add_argument("--time_breakdown", action='store_true', default=False, help="Show detailed time breakdown of each step during the tuning")
parser.add_argument("--verbose", action='store_true', default=False, help="enables time_breakdown and additional logging messages")
parser.add_argument("--num_threads", type=int, default=16, help="number of threads to use for kernel compilation and post processing")
args = parser.parse_args()
return args
@@ -422,6 +456,16 @@ def main():
tuning_output_file = args.tuning_results_file
keepTmp = args.keep
ngpus = args.ngpus
gpu_ids = args.gpu_ids
if ngpus != 0 and gpu_ids:
print("--ngpus and --gpu_ids are mutually exclusive options")
return os.EX_USAGE
if ngpus == 0 and not gpu_ids:
ngpus = 1
if ngpus != 0:
gpus = range(ngpus)
if gpu_ids:
gpus = gpu_ids
mnks = []
## TODO: make it more robust to get user input
@@ -454,7 +498,7 @@ def main():
configs_full = get_full_tuning_space()
start_time = datetime.now()
print(f"Tuning starts at: {start_time}")
print(f"Tuning starts at: {start_time}", flush=True)
f_results = open(tuning_output_file, 'w')
for (M, N, K) in mnks:
@@ -466,7 +510,12 @@ def main():
print(f"{size_str} nConfigs: {len(pruned_configs)}", end=" ", flush=True)
## The main tuning function for one gemm size
minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(M, N, K, pruned_configs, ngpus = ngpus, verbose=args.time_breakdown)
verbose_level = 0
if args.time_breakdown:
verbose_level = 1
if args.verbose:
verbose_level = 2
minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(M, N, K, pruned_configs, num_threads=args.num_threads, gpus = gpus, verbose=verbose_level)
## post processing the numbers
perf_tflops = lambda us: 2 * M * N * K * 1e-12 / (us * 1e-6)
@@ -475,10 +524,10 @@ def main():
formatted_tflops = "{:.3e}".format(tri_tflops)
else:
formatted_tflops = "{:.2f}".format(tri_tflops)
print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', end=" ")
print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', end=" ", flush=True)
bestConfig_compact_str, _ = gen_kernel_and_configStr_from_config(M, N, K, bestConfig)
print(f'best_config: {bestConfig_compact_str}', end=" ")
print(f'best_config: {bestConfig_compact_str}', end=" ", flush=True)
## write best config to tuning_results.yaml
sizeDict = {'M': M, 'N': N, 'K': K}
@@ -488,20 +537,22 @@ def main():
## remove generated files if asked to
if not keepTmp:
for fi in range(ngpus):
os.remove(f"generated_kernel{M}{N}{K}-{fi}.py")
for f in glob.glob(f"results-{fi}.*"):
for gpu_id in gpus:
generated_script = generated_kernel_name(M, N, K, gpu_id)
os.remove(generated_script)
os.remove(generated_script + ".failed_configs")
for f in glob.glob(f"results-{gpu_id}.*"):
os.remove(f)
## Check correctness if asked to
if args.compare:
print("correctness: ", end=" ")
print("correctness: ", end=" ", flush=True)
test_correctness(M, N, K, bestConfig, False)
else:
print("")
print("", flush=True)
end_local_time = datetime.now()
print(f">>> Elapsed time: {end_local_time - start_local_time} = {compile_time} (compile) + {profile_time} (profile) + {post_time} (post processing)")
print(f">>> Elapsed time: {end_local_time - start_local_time} = {compile_time} (compile) + {profile_time} (profile) + {post_time} (post processing)", flush=True)
f_results.close()