mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[GEMM] [Tuning] Make tuning script more verbose (#420)
This PR adds: - verbose tuning mode: printing std output of compilation and tuning calls - collecting information about failed compilations - print correctness check output with word - split dimensions in generated scripts with "-" - gpu_ids option to set particular gpus
This commit is contained in:
@@ -87,9 +87,12 @@ def need_split_k(SIZE_M, SIZE_N, SIZE_K):
|
||||
return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
|
||||
|
||||
|
||||
def run_bash_command(commandstring):
|
||||
proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash', stdout = subprocess.PIPE)
|
||||
return proc.stdout.splitlines()
|
||||
def run_bash_command(commandstring, capture=True):
|
||||
if capture:
|
||||
proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash', stdout = subprocess.PIPE)
|
||||
return proc.stdout.splitlines()
|
||||
proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash')
|
||||
return None
|
||||
|
||||
|
||||
def read_config(config):
|
||||
@@ -113,7 +116,7 @@ def matmul_{configStr}(a, b, c, M, N, K, am, ak, bk, bn, cm, cn, warmup=False):
|
||||
#M, K = a.shape
|
||||
#K, N = b.shape
|
||||
grid = triton.cdiv(M, {block_m}) * triton.cdiv(N, {block_n}), {split_k}
|
||||
print(f'config: matmul_kernel_{configStr}')
|
||||
print(f'config: matmul_kernel_{configStr}', flush=True)
|
||||
if warmup:
|
||||
matmul_kernel_{configStr}.warmup(
|
||||
torch.float16, torch.float16, torch.float16,
|
||||
@@ -129,6 +132,7 @@ def matmul_{configStr}(a, b, c, M, N, K, am, ak, bk, bn, cm, cn, warmup=False):
|
||||
waves_per_eu = {waves_per_eu},
|
||||
grid=(1,)
|
||||
)
|
||||
return None
|
||||
else:
|
||||
matmul_kernel_{configStr}[grid](
|
||||
a, b, c,
|
||||
@@ -143,7 +147,7 @@ def matmul_{configStr}(a, b, c, M, N, K, am, ak, bk, bn, cm, cn, warmup=False):
|
||||
num_stages = {num_stages},
|
||||
waves_per_eu = {waves_per_eu}
|
||||
)
|
||||
return c
|
||||
return c
|
||||
|
||||
def try_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, dtype):
|
||||
#a = torch.randn((M, K), device='cuda', dtype=dtype)
|
||||
@@ -151,13 +155,20 @@ def try_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, dtype):
|
||||
#c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
|
||||
try:
|
||||
matmul_{configStr}(None, None, None, M, N, K, am, ak, bk, bn, cm, cn, True)
|
||||
except Exception:
|
||||
print(f'invalid config {configStr}')
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f'invalid config(compilation): {configStr}: ', e, flush=True)
|
||||
return False
|
||||
"""
|
||||
return configStr, matmul_def_str
|
||||
|
||||
## Open {ngpus} files
|
||||
## generated_kernelMNK-0.py, generated_kernelMNK-1.py, ..., generated_kernelMNK-{ngpus-1}.py
|
||||
|
||||
def generated_kernel_name(M, N, K, gpu_id):
|
||||
return f"generated_kernel{M}-{N}-{K}-{gpu_id}.py"
|
||||
|
||||
|
||||
## Open {len(gpus)} files
|
||||
## generated_kernelM-N-K-{gpus[0]}.py, generated_kernelM-N-K-{gpus[1]}.py, ..., generated_kernelM-N-K-{gpus[-1]}.py
|
||||
## and generate
|
||||
## 1. matmul kernels of all configs
|
||||
## 2. wrapper function matmul to invoke all the generated kernels
|
||||
@@ -165,10 +176,11 @@ def try_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, dtype):
|
||||
## 4. test_gemm to invoke
|
||||
## 4.1 run try_config in parallel
|
||||
## 4.2 matmul in a loop of 10 iterations
|
||||
def generate_kernel(M, N, K, configs, ngpus):
|
||||
def generate_kernel(M, N, K, configs, gpus):
|
||||
filenames = []
|
||||
for fi in range(ngpus):
|
||||
filenames.append(f"generated_kernel{M}{N}{K}-{fi}.py")
|
||||
ngpus = len(gpus)
|
||||
for gpu_id in gpus:
|
||||
filenames.append(generated_kernel_name(M, N, K, gpu_id))
|
||||
f_kernel = [open(path, 'w') for path in filenames]
|
||||
|
||||
### write imports
|
||||
@@ -185,7 +197,7 @@ import multiprocessing
|
||||
### write definitions of matmul_kernel_xxx
|
||||
### and matmul_xxx and try_config
|
||||
with open("matmul_kernel.py") as file:
|
||||
matmul_kernel_code = file.read();
|
||||
matmul_kernel_code = file.read()
|
||||
idx = 0
|
||||
for config in configs:
|
||||
file_idx = idx % ngpus
|
||||
@@ -211,6 +223,8 @@ import multiprocessing
|
||||
c.stride(0), c.stride(1), dtype)
|
||||
|
||||
if num_threads > 1:
|
||||
results = []
|
||||
config_names = []
|
||||
"""
|
||||
for fi in range(ngpus):
|
||||
f_kernel[fi].write(test_gemm_pre_str + "\n")
|
||||
@@ -219,23 +233,40 @@ import multiprocessing
|
||||
idx = 0
|
||||
for config in configs:
|
||||
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
|
||||
task_str = f" thread_pool.apply_async(try_config_{configStr}, args=task_args)\n"
|
||||
task_str = f" results += [thread_pool.apply_async(try_config_{configStr}, args=task_args)]\n" + \
|
||||
f" config_names += ['{configStr}']\n"
|
||||
f_kernel[idx % ngpus].write(task_str)
|
||||
idx += 1
|
||||
|
||||
threadpool_str = """
|
||||
for fi in range(ngpus):
|
||||
threadpool_str = """
|
||||
failed_configs = []
|
||||
for i in range(len(results)):
|
||||
results[i].wait()
|
||||
res = results[i].get()
|
||||
if not res:
|
||||
failed_configs += [config_names[i]]
|
||||
thread_pool.close()
|
||||
thread_pool.join()
|
||||
else:"""
|
||||
for fi in range(ngpus):
|
||||
with open("{filename}.failed_configs", "w") as f:
|
||||
for cfg in failed_configs:
|
||||
f.write(cfg + "\\n")
|
||||
else:
|
||||
try:
|
||||
with open("{filename}.failed_configs", "r") as f:
|
||||
failed_configs = [cfg.strip() for cfg in f.readlines()]
|
||||
except Exception:
|
||||
failed_configs = []
|
||||
""".format(filename = filenames[fi])
|
||||
f_kernel[fi].write(threadpool_str)
|
||||
# call all matmul_xxx functions
|
||||
idx = 0
|
||||
for config in configs:
|
||||
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
|
||||
matmul_call_str = f"""
|
||||
for i in range(10):
|
||||
d = matmul_{configStr}(a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1))"""
|
||||
if '{configStr}' not in failed_configs:
|
||||
for i in range(10):
|
||||
d = matmul_{configStr}(a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1))"""
|
||||
f_kernel[idx % ngpus].write(matmul_call_str + "\n")
|
||||
idx += 1
|
||||
# post string
|
||||
@@ -267,30 +298,29 @@ def extract_kernel_time(M, N, K, config, gpuid):
|
||||
return config, parsed_outputs
|
||||
|
||||
|
||||
def profile_batch_kernels(M, N, K, gpuid):
|
||||
def profile_batch_kernels(M, N, K, gpuid, verbose):
|
||||
os.environ['ROCR_VISIBLE_DEVICES'] = str(gpuid)
|
||||
run_bash_command(f"rocprof --stats -o results-{gpuid}.csv python generated_kernel{M}{N}{K}-{gpuid}.py")
|
||||
run_bash_command(f"rocprof --stats -o results-{gpuid}.csv python {generated_kernel_name(M, N, K, gpuid)}", capture=(verbose < 2))
|
||||
|
||||
|
||||
def tune_gemm_config(M, N, K, configs, verbose=False, num_threads=16, ngpus = 1):
|
||||
def tune_gemm_config(M, N, K, configs, verbose=0, num_threads=16, gpus = [0]):
|
||||
## Generate kernel out of all configs
|
||||
generate_kernel(M, N, K, configs, ngpus)
|
||||
generate_kernel(M, N, K, configs, gpus)
|
||||
|
||||
## remove any compiled kernel in the cache
|
||||
run_bash_command("rm -rf ~/.triton/cache")
|
||||
|
||||
## precompile the kernels in parallel
|
||||
## TODO: parameterize numThreads at this level
|
||||
start_time = datetime.now()
|
||||
for fi in range(ngpus):
|
||||
run_bash_command(f"python generated_kernel{M}{N}{K}-{fi}.py -n 32")
|
||||
for gpu_id in gpus:
|
||||
run_bash_command(f"python {generated_kernel_name(M, N, K, gpu_id)} -n {num_threads}", capture=(verbose < 2))
|
||||
compile_end = datetime.now()
|
||||
compile_time = compile_end - start_time
|
||||
if verbose:
|
||||
print(f"compile time: {compile_time}")
|
||||
print(f"compile time: {compile_time}", flush=True)
|
||||
|
||||
## profile generated kernels
|
||||
running = [multiprocessing.Process(target=profile_batch_kernels, args=(M,N,K,fi)) for fi in range(ngpus)]
|
||||
running = [multiprocessing.Process(target=profile_batch_kernels, args=(M,N,K,gpu_id,verbose)) for gpu_id in gpus]
|
||||
for p in running:
|
||||
p.start()
|
||||
for p in running:
|
||||
@@ -299,7 +329,7 @@ def tune_gemm_config(M, N, K, configs, verbose=False, num_threads=16, ngpus = 1)
|
||||
profile_end = datetime.now()
|
||||
profile_time = profile_end - compile_end
|
||||
if verbose:
|
||||
print(f"profile time: {profile_time}")
|
||||
print(f"profile time: {profile_time}", flush=True)
|
||||
|
||||
## post process results.csv to get the best config and minTime
|
||||
## TODO: process the file in parallel
|
||||
@@ -308,8 +338,9 @@ def tune_gemm_config(M, N, K, configs, verbose=False, num_threads=16, ngpus = 1)
|
||||
tasks = []
|
||||
idx = 0
|
||||
for config in configs:
|
||||
file_idx = idx % ngpus
|
||||
tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, file_idx))]
|
||||
file_idx = idx % len(gpus)
|
||||
gpu_id = gpus[file_idx]
|
||||
tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, gpu_id))]
|
||||
idx += 1
|
||||
thread_pool.close()
|
||||
thread_pool.join()
|
||||
@@ -323,11 +354,11 @@ def tune_gemm_config(M, N, K, configs, verbose=False, num_threads=16, ngpus = 1)
|
||||
bestConfig = config
|
||||
else:
|
||||
min_us = -1
|
||||
print(f"invalid config: SIZE {M} {N} {K}: {config}")
|
||||
print(f"invalid config(post processing): SIZE {M} {N} {K}: {config}", flush=True)
|
||||
post_end = datetime.now()
|
||||
post_time = post_end - profile_end
|
||||
if verbose:
|
||||
print(f"post procesing time: {post_time}")
|
||||
print(f"post procesing time: {post_time}", flush=True)
|
||||
return minTime, bestConfig, compile_time, profile_time, post_time
|
||||
|
||||
|
||||
@@ -379,9 +410,9 @@ def test_correctness(M, N, K, config, verbose, datatype = torch.float16):
|
||||
if verbose:
|
||||
size_str = f'SIZE M: {M}, N: {N}, K: {K} '
|
||||
if torch.allclose(triton_output, torch_output, atol=1e-1, rtol=rtol):
|
||||
print(f'{size_str}✅')
|
||||
print(f'{size_str} Correct✅')
|
||||
else:
|
||||
print(f'{size_str}❌')
|
||||
print(f'{size_str} Incorrect❌')
|
||||
|
||||
|
||||
def get_default_tuning_result_filename():
|
||||
@@ -404,13 +435,16 @@ def parse_args():
|
||||
parser.add_argument("-m", type=int, default=0)
|
||||
parser.add_argument("-n", type=int, default=0)
|
||||
parser.add_argument("-k", type=int, default=0)
|
||||
parser.add_argument("--ngpus", type=int, default=1, help='number of GPUs used in the profiling step')
|
||||
parser.add_argument("--ngpus", type=int, default=0, help='number of GPUs used in the profiling step')
|
||||
parser.add_argument("--gpu_ids", type=lambda s: [int(id) for id in s.split(',')], default=[], help='list of gpu ids to use for tuning')
|
||||
parser.add_argument("--gemm_size_file", type=str, default="", help='yaml file to indicate matrix size')
|
||||
parser.add_argument("--tuning_results_file", type=str, default=get_default_tuning_result_filename(), help='yaml file to store tuning results')
|
||||
parser.add_argument("--keep", action='store_true', default=False, help='keep generated files')
|
||||
parser.add_argument("--compare", action='store_true', default=False, help="Whether check result correctness")
|
||||
parser.add_argument("--compare_wo_tuning", action='store_true', default=False, help="Whether check result correctness")
|
||||
parser.add_argument("--time_breakdown", action='store_true', default=False, help="Show detailed time breakdown of each step during the tuning")
|
||||
parser.add_argument("--verbose", action='store_true', default=False, help="enables time_breakdown and additional logging messages")
|
||||
parser.add_argument("--num_threads", type=int, default=16, help="number of threads to use for kernel compilation and post processing")
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
@@ -422,6 +456,16 @@ def main():
|
||||
tuning_output_file = args.tuning_results_file
|
||||
keepTmp = args.keep
|
||||
ngpus = args.ngpus
|
||||
gpu_ids = args.gpu_ids
|
||||
if ngpus != 0 and gpu_ids:
|
||||
print("--ngpus and --gpu_ids are mutually exclusive options")
|
||||
return os.EX_USAGE
|
||||
if ngpus == 0 and not gpu_ids:
|
||||
ngpus = 1
|
||||
if ngpus != 0:
|
||||
gpus = range(ngpus)
|
||||
if gpu_ids:
|
||||
gpus = gpu_ids
|
||||
|
||||
mnks = []
|
||||
## TODO: make it more robust to get user input
|
||||
@@ -454,7 +498,7 @@ def main():
|
||||
configs_full = get_full_tuning_space()
|
||||
|
||||
start_time = datetime.now()
|
||||
print(f"Tuning starts at: {start_time}")
|
||||
print(f"Tuning starts at: {start_time}", flush=True)
|
||||
|
||||
f_results = open(tuning_output_file, 'w')
|
||||
for (M, N, K) in mnks:
|
||||
@@ -466,7 +510,12 @@ def main():
|
||||
print(f"{size_str} nConfigs: {len(pruned_configs)}", end=" ", flush=True)
|
||||
|
||||
## The main tuning funtion for one gemm size
|
||||
minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(M, N, K, pruned_configs, ngpus = ngpus, verbose=args.time_breakdown)
|
||||
verbose_level = 0
|
||||
if args.time_breakdown:
|
||||
verbose_level = 1
|
||||
if args.verbose:
|
||||
verbose_level = 2
|
||||
minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(M, N, K, pruned_configs, num_threads=args.num_threads, gpus = gpus, verbose=verbose_level)
|
||||
|
||||
## post processing the numbers
|
||||
perf_tflops = lambda us: 2 * M * N * K * 1e-12 / (us * 1e-6)
|
||||
@@ -475,10 +524,10 @@ def main():
|
||||
formatted_tflops = "{:.3e}".format(tri_tflops)
|
||||
else:
|
||||
formatted_tflops = "{:.2f}".format(tri_tflops)
|
||||
print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', end=" ")
|
||||
print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', end=" ", flush=True)
|
||||
|
||||
bestConfig_compact_str, _ = gen_kernel_and_configStr_from_config(M, N, K, bestConfig)
|
||||
print(f'best_config: {bestConfig_compact_str}', end=" ")
|
||||
print(f'best_config: {bestConfig_compact_str}', end=" ", flush=True)
|
||||
|
||||
## write best config to tuning_results.yaml
|
||||
sizeDict = {'M': M, 'N': N, 'K': K}
|
||||
@@ -488,20 +537,22 @@ def main():
|
||||
|
||||
## remove generated files if asked to
|
||||
if not keepTmp:
|
||||
for fi in range(ngpus):
|
||||
os.remove(f"generated_kernel{M}{N}{K}-{fi}.py")
|
||||
for f in glob.glob(f"results-{fi}.*"):
|
||||
for gpu_id in gpus:
|
||||
generated_script = generated_kernel_name(M, N, K, gpu_id)
|
||||
os.remove(generated_script)
|
||||
os.remove(generated_script + ".failed_configs")
|
||||
for f in glob.glob(f"results-{gpu_id}.*"):
|
||||
os.remove(f)
|
||||
|
||||
## Check correctness if asked to
|
||||
if args.compare:
|
||||
print("correctness: ", end=" ")
|
||||
print("correctness: ", end=" ", flush=True)
|
||||
test_correctness(M, N, K, bestConfig, False)
|
||||
else:
|
||||
print("")
|
||||
print("", flush=True)
|
||||
|
||||
end_local_time = datetime.now()
|
||||
print(f">>> Elapsed time: {end_local_time - start_local_time} = {compile_time} (compile) + {profile_time} (profile) + {post_time} (post processing)")
|
||||
print(f">>> Elapsed time: {end_local_time - start_local_time} = {compile_time} (compile) + {profile_time} (profile) + {post_time} (post processing)", flush=True)
|
||||
|
||||
f_results.close()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user