mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[Tuning] Gemm tuning v3 (#457)
* Add gemm tuning script v3 * Introduce --jobs to control the number of files to generate * Switch to trans convention used by Tensile * Rerun rocprof if it crashes * update README * Remove peak perf and efficiency
This commit is contained in:
@@ -31,7 +31,7 @@ python tune_gemm.py -m 16 -n 16 -k 16
|
||||
|
||||
3. Choose the file to store tuning results
|
||||
```bash
|
||||
python tune_gemm.py --gemm_size_file input_gemm_sizes.yaml --tuning_results_file output_tuning.yaml
|
||||
python tune_gemm.py --gemm_size_file input_gemm_sizes.yaml --o output_tuning.yaml
|
||||
```
|
||||
|
||||
4. Only check correctness given the tuning results
|
||||
@@ -50,7 +50,8 @@ Workflow of the tuning process
|
||||
- When split-k is not needed, i.e. both M and N are large, it must be 1
|
||||
- GROUP_M * BLOCK_SIZE_M must be smaller than M. Otherwise, GROUP_M must be 1
|
||||
- When BLOCK_SIZE_K = 128, neither BLOCK_SIZE_M nor BLOCK_SIZE_N can be 128. Otherwise too much LDS will be required. **Needs further investigation**
|
||||
3. Open a file `generated_kernel{M}{N}{K}.py` and write the following into the file
|
||||
- Skip BLOCK_SIZE_M or BLOCK_SIZE_N if they are over 2 times larger than M or N.
|
||||
3. Open a file `generated_kernel{M}-{N}-{K}-{gpuid}.py` and write the following into the file
|
||||
1. For each config in the pruned space, generate a kernel with name `matmul_kernel_{configStr}`, where `configStr` contains the gemm size and the tuning parameters.
|
||||
2. Generate `matmul` function for each config in a similar way
|
||||
3. Generate `try_config` functions for each `matmul` function.
|
||||
@@ -62,13 +63,62 @@ Workflow of the tuning process
|
||||
5. Invoke `rocprof` on the generated script
|
||||
6. Post-process `results.csv` by extracting the execution time of the last instance of each kernel. Pick the best one, write it to file, and return.
|
||||
|
||||
### Known issues
|
||||
On some node, I saw the following runtime error
|
||||
# GEMM Tuning Script v3
|
||||
|
||||
### API changes
|
||||
|
||||
- Input and output data types can be provided as `-dtype_a`, `-dtype_b`, and `-dtype_c`.
|
||||
The provided types must be one of ['fp32', 'fp16', 'bf16', 'fp8', 'bf8', 'int8'].
|
||||
- Row/col major-ness of operand a and b can be provided as `-col_a` and `-col_b`.
|
||||
If set, it means the corresponding operand is column major.
|
||||
The major-ness is considered as problem input.
|
||||
So they should be included in the input yaml file. However, in the yaml file, user should
|
||||
set `rowMajorA` and `rowMajorB` as shown in the example below.
|
||||
- `--benchmark` is used to control if the perf config in the input yaml file is used as the tuning space.
|
||||
- `--jobs` is used to control the number of .py files for generated kernels.
|
||||
Note that this can be different from `ngpus`. This usually means multiple kernel files
|
||||
will be profiled on each GPU.
|
||||
This is necessary to keep each file "small" in terms of execution time.
|
||||
|
||||
### Implementation changes
|
||||
- `gen_input` is used to generate matmul inputs.
|
||||
- Time measurement
|
||||
- In benchmark mode, the kernel is executed 1000 times.
|
||||
- In tuning mode, each kernel is executed 200 times. We cannot afford larger runs since rocprof hangs if the session takes too long.
|
||||
- In both tuning and benchmark mode, kernel time is measured as the average execution time of the last 100 instances.
|
||||
- Added error recovery. This helps when rocprof crashes in multi-processing mode.
|
||||
|
||||
|
||||
### Example Usage
|
||||
|
||||
Let's say we have an input yaml file, named `gemm_input.yaml`, that contains the following configs
|
||||
```yaml
|
||||
- {'M': 4864, 'N': 4096, 'K': 8192, 'rowMajorA': 'T', 'rowMajorB': 'N'}
|
||||
- {'M': 8192, 'N': 8192, 'K': 8192, 'rowMajorA': 'T', 'rowMajorB': 'N'}
|
||||
```
|
||||
:0:rocdevice.cpp :2776: 7321835745146 us: 1401 : [tid:0x7fc930830700] Callback: Queue 0x7fc9b7200000 aborting with error : HSA_STATUS_ERROR_INVALID_ISA: The instruction set architecture is invalid. code: 0x100f
|
||||
1. Tuning with bf8 input types with gpu 4,5,6,7, and save output to `output.yaml`
|
||||
```bash
|
||||
python ./tune_gemm.py --gemm_size_file gemm_input.yaml -dtype_a bf8 -dtype_b bf8 --gpu_ids 4,5,6,7 --o output.yaml
|
||||
```
|
||||
It's hard to reproduce the error. **Needs further investigation**
|
||||
- https://github.com/ROCmSoftwarePlatform/frameworks-internal/issues/6011
|
||||
|
||||
2. Check the correctness of the tuned configs
|
||||
```bash
|
||||
python ./tune_gemm.py --gemm_size_file output.yaml -dtype_a bf8 -dtype_b bf8 --compare_wo_tuning
|
||||
```
|
||||
|
||||
3. Run benchmark of the tuned configs
|
||||
```bash
|
||||
python ./tune_gemm.py --gemm_size_file output.yaml -dtype_a bf8 -dtype_b bf8 --benchmark
|
||||
```
|
||||
|
||||
A sample output from `benchmark` looks like
|
||||
```bash
|
||||
Benchmarking gemm with bf8 inputs (peak tflops: 1298)
|
||||
trans M N K TFLOPS Efficiency
|
||||
NT 4864 4096 8192 841.22 65%
|
||||
NT 8192 8192 8192 745.31 57%
|
||||
```
|
||||
|
||||
|
||||
# One config running script
|
||||
|
||||
|
||||
@@ -33,14 +33,15 @@ def matmul_kernel(
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))
|
||||
a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
|
||||
b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
acc_dtype = tl.float32 if c_ptr.type.element_ty != tl.int8 else tl.int32
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
|
||||
a = tl.load(a_ptrs)
|
||||
b = tl.load(b_ptrs)
|
||||
accumulator += tl.dot(a, b)
|
||||
a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
|
||||
c = accumulator.to(tl.float16)
|
||||
c = accumulator.to(c_ptr.type.element_ty)
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||
|
||||
@@ -13,6 +13,7 @@ from matmul_kernel import matmul_kernel
|
||||
|
||||
from datetime import datetime
|
||||
import multiprocessing
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def get_full_tuning_space():
|
||||
@@ -22,12 +23,12 @@ def get_full_tuning_space():
|
||||
block_k_range = [16, 32, 64, 128, 256]
|
||||
split_k_range = [1, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24]
|
||||
num_warps_range = [1, 2, 4, 8]
|
||||
group_m_range = [1, 4, 8]
|
||||
group_m_range = [1, 4, 8, 16, 32]
|
||||
# For now we see better perf with num_stages=0 for all gemm configs we care
|
||||
# But keep this explicit so that we do not forget we may need to set it to
|
||||
# other values in the future
|
||||
num_stage_range = [1, 0]
|
||||
waves_per_eu_range = [0,1,2,3,4]
|
||||
num_stage_range = [0]
|
||||
waves_per_eu_range = [0]
|
||||
|
||||
for block_m in block_mn_range:
|
||||
for block_n in block_mn_range:
|
||||
@@ -42,7 +43,7 @@ def get_full_tuning_space():
|
||||
return configs
|
||||
|
||||
|
||||
def prune_configs(M, N, K, configs):
|
||||
def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b):
|
||||
pruned_configs = []
|
||||
|
||||
if M < 32 or N < 32:
|
||||
@@ -50,10 +51,16 @@ def prune_configs(M, N, K, configs):
|
||||
else:
|
||||
mfma = 32
|
||||
|
||||
# TODO (zhanglx): figure out the boundary between large and small gemms
|
||||
large_gemm = False
|
||||
if M >= 2048 and N >=2048:
|
||||
large_gemm = True
|
||||
|
||||
for config in configs:
|
||||
BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
|
||||
BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
|
||||
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
|
||||
num_warps = config.get("num_warps")
|
||||
if mfma == 4 and BLOCK_SIZE_K < 64:
|
||||
continue
|
||||
# some layouts could not work properly in case
|
||||
@@ -82,10 +89,21 @@ def prune_configs(M, N, K, configs):
|
||||
# skip large GROUP_M
|
||||
if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
|
||||
continue
|
||||
## out of shared memory resource
|
||||
LDS = BLOCK_SIZE_K * BLOCK_SIZE_M + BLOCK_SIZE_K * BLOCK_SIZE_N
|
||||
if LDS * 2 > 65536:
|
||||
# out of shared memory resource
|
||||
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
|
||||
LDS = BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
|
||||
if LDS > 65536:
|
||||
continue
|
||||
# Skip small block sizes and num_warps for large gemm
|
||||
# For fp8, we want to only use BLOCK_SIZE >= 128
|
||||
# For fp16, we want to only use BLOCK_SIZE >= 64
|
||||
if large_gemm:
|
||||
if BLOCK_SIZE_M < (128/elemBytes_a) or BLOCK_SIZE_N < (128/elemBytes_a):
|
||||
continue
|
||||
if BLOCK_SIZE_K < 64:
|
||||
continue
|
||||
if num_warps < 4:
|
||||
continue
|
||||
|
||||
pruned_configs.append(config)
|
||||
|
||||
@@ -95,15 +113,21 @@ def prune_configs(M, N, K, configs):
|
||||
def need_split_k(SIZE_M, SIZE_N, SIZE_K):
    """Return True when the gemm shape calls for split-K.

    Split-K is requested only when at least one of the output dimensions
    (M or N) is below 64 and the reduction dimension K exceeds 1024.
    """
    # Both output dimensions are at least 64: never split K.
    if SIZE_M >= 64 and SIZE_N >= 64:
        return False
    return SIZE_K > 1024
|
||||
|
||||
def run_bash_command_wrapper(commandstring, capture=True):
    """Run a shell command, retrying once on failure when not capturing.

    Args:
        commandstring: Shell command line handed to ``run_bash_command``.
        capture: Forwarded to ``run_bash_command``.

    Returns:
        Whatever ``run_bash_command`` returns for the attempt that succeeded,
        or ``None`` if the command failed while ``capture`` was true.

    Raises:
        subprocess.CalledProcessError: If the retry (``capture`` false) also
            fails.
    """
    try:
        return run_bash_command(commandstring, capture)
    except subprocess.CalledProcessError as e:
        if not capture:
            print(f"running {commandstring} one more time")
            return run_bash_command(commandstring, capture)
        # NOTE(review): a failure while capturing is deliberately not retried
        # and not re-raised (best-effort, as before); report it instead of
        # swallowing it silently. Confirm whether the retry should also apply
        # when capture is true.
        print(f"{commandstring} failed with exit code {e.returncode}")
        return None
|
||||
|
||||
def run_bash_command(commandstring, capture=True):
    """Run ``commandstring`` through /bin/bash exactly once.

    Args:
        commandstring: Shell command line to execute.
        capture: When true, stdout is captured and returned; otherwise the
            command inherits this process's stdout.

    Returns:
        A list of stdout lines (``bytes``) when ``capture`` is true,
        otherwise ``None``.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status (``check=True``).
    """
    if capture:
        proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash', stdout=subprocess.PIPE)
        return proc.stdout.splitlines()
    subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash')
    return None
|
||||
|
||||
|
||||
def read_config(config):
|
||||
block_m = config.get('BLOCK_SIZE_M')
|
||||
block_n = config.get('BLOCK_SIZE_N')
|
||||
@@ -125,7 +149,7 @@ def matmul_{configStr}(a, b, c, M, N, K, am, ak, bk, bn, cm, cn, warmup=False):
|
||||
#M, K = a.shape
|
||||
#K, N = b.shape
|
||||
grid = triton.cdiv(M, {block_m}) * triton.cdiv(N, {block_n}), {split_k}
|
||||
print(f'config: matmul_kernel_{configStr}', flush=True)
|
||||
#print(f'config: matmul_kernel_{configStr}', flush=True)
|
||||
if warmup:
|
||||
matmul_kernel_{configStr}.warmup(
|
||||
torch.float16, torch.float16, torch.float16,
|
||||
@@ -158,10 +182,7 @@ def matmul_{configStr}(a, b, c, M, N, K, am, ak, bk, bn, cm, cn, warmup=False):
|
||||
)
|
||||
return c
|
||||
|
||||
def try_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, dtype):
|
||||
#a = torch.randn((M, K), device='cuda', dtype=dtype)
|
||||
#b = torch.randn((K, N), device='cuda', dtype=dtype)
|
||||
#c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
|
||||
def try_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn):
|
||||
try:
|
||||
matmul_{configStr}(None, None, None, M, N, K, am, ak, bk, bn, cm, cn, True)
|
||||
return True
|
||||
@@ -176,42 +197,42 @@ def generated_kernel_name(M, N, K, gpu_id):
|
||||
return f"generated_kernel{M}-{N}-{K}-{gpu_id}.py"
|
||||
|
||||
|
||||
## Open {len(gpus)} files
|
||||
## generated_kernelM-N-K-{gpus[0]}.py, generated_kernelM-N-K-{gpus[1]}.py, ..., generated_kernelM-N-K-{gpus[-1]}.py
|
||||
## and generate
|
||||
## 1. matmul kernels of all configs
|
||||
## 2. wrapper function matmul to invoke all the generated kernels
|
||||
## 3. Another wraper function try_config to invoke matmul function
|
||||
## 4. test_gemm to invoke
|
||||
## 4.1 run try_config in parallel
|
||||
## 4.2 matmul in a loop of 10 iterations
|
||||
def generate_kernel(M, N, K, configs, gpus):
|
||||
# Open {len(gpus)} files
|
||||
# generated_kernelM-N-K-{gpus[0]}.py, generated_kernelM-N-K-{gpus[1]}.py, ..., generated_kernelM-N-K-{gpus[-1]}.py
|
||||
# and generate
|
||||
# 1. matmul kernels of all configs
|
||||
# 2. wrapper function matmul to invoke all the generated kernels
|
||||
# 3. Another wrapper function try_config to invoke matmul function
|
||||
# 4. test_gemm to invoke
|
||||
# 4.1 run try_config in parallel
|
||||
# 4.2 matmul in a loop of `runs` iterations (1000 in benchmark mode, 200 otherwise)
|
||||
def generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs, jobs, run_bench):
|
||||
filenames = []
|
||||
ngpus = len(gpus)
|
||||
for gpu_id in gpus:
|
||||
filenames.append(generated_kernel_name(M, N, K, gpu_id))
|
||||
for i in range(jobs):
|
||||
filenames.append(generated_kernel_name(M, N, K, i))
|
||||
f_kernel = [open(path, 'w') for path in filenames]
|
||||
|
||||
### write imports
|
||||
# write imports
|
||||
import_str = """import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import argparse
|
||||
import sys
|
||||
import multiprocessing
|
||||
from tune_gemm import gen_input
|
||||
"""
|
||||
for fi in range(ngpus):
|
||||
for fi in range(jobs):
|
||||
f_kernel[fi].write(import_str + "\n")
|
||||
|
||||
### write definitions of matmul_kernel_xxx
|
||||
### and matmul_xxx and try_config
|
||||
# write definitions of matmul_kernel_xxx
|
||||
# and matmul_xxx and try_config
|
||||
with open("matmul_kernel.py") as file:
|
||||
matmul_kernel_code = file.read()
|
||||
idx = 0
|
||||
for config in configs:
|
||||
file_idx = idx % ngpus
|
||||
file_idx = idx % jobs
|
||||
configStr, matmul_def_str = gen_kernel_and_configStr_from_config(M, N, K, config)
|
||||
## Copy the matmul_kernel with name replaced
|
||||
# Copy the matmul_kernel with name replaced
|
||||
matmul_kernel_config = matmul_kernel_code.replace("matmul_kernel", f"matmul_kernel_{configStr}")
|
||||
matmul_kernel_config = matmul_kernel_config.replace("import triton.language as tl", "")
|
||||
matmul_kernel_config = matmul_kernel_config.replace("import triton", "")
|
||||
@@ -219,23 +240,23 @@ import multiprocessing
|
||||
f_kernel[file_idx].write(matmul_def_str + "\n")
|
||||
idx += 1
|
||||
|
||||
### write test_gemm
|
||||
# write test_gemm
|
||||
# pre string
|
||||
test_gemm_pre_str = """def test_gemm(M, N, K, dtype, num_threads):
|
||||
test_gemm_pre_str = f"""def test_gemm(M, N, K, num_threads):
|
||||
thread_pool = multiprocessing.Pool(processes=num_threads)
|
||||
a = torch.randn((M, K), device='cuda', dtype=dtype)
|
||||
b = torch.randn((K, N), device='cuda', dtype=dtype)
|
||||
c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
|
||||
a, a_fp16 = gen_input(M, K, '{dtype_a}', {col_a}, 1, device='cuda')
|
||||
b, b_fp16 = gen_input(K, N, '{dtype_b}', {col_b}, 2, device='cuda')
|
||||
c = torch.zeros((M, N), device=a.device, dtype={tl_to_torch_types[name_to_tl_types[dtype_c]]})
|
||||
task_args = (M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1), dtype)
|
||||
c.stride(0), c.stride(1))
|
||||
|
||||
if num_threads > 1:
|
||||
results = []
|
||||
config_names = []
|
||||
"""
|
||||
for fi in range(ngpus):
|
||||
for fi in range(jobs):
|
||||
f_kernel[fi].write(test_gemm_pre_str + "\n")
|
||||
|
||||
# warm up call of all matmul functions in parallel
|
||||
@@ -244,10 +265,10 @@ import multiprocessing
|
||||
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
|
||||
task_str = f" results += [thread_pool.apply_async(try_config_{configStr}, args=task_args)]\n" + \
|
||||
f" config_names += ['{configStr}']\n"
|
||||
f_kernel[idx % ngpus].write(task_str)
|
||||
f_kernel[idx % jobs].write(task_str)
|
||||
idx += 1
|
||||
|
||||
for fi in range(ngpus):
|
||||
for fi in range(jobs):
|
||||
threadpool_str = """
|
||||
failed_configs = []
|
||||
for i in range(len(results)):
|
||||
@@ -266,23 +287,24 @@ import multiprocessing
|
||||
failed_configs = [cfg.strip() for cfg in f.readlines()]
|
||||
except Exception:
|
||||
failed_configs = []
|
||||
""".format(filename = filenames[fi])
|
||||
""".format(filename=filenames[fi])
|
||||
f_kernel[fi].write(threadpool_str)
|
||||
# call all matmul_xxx functions
|
||||
idx = 0
|
||||
runs = 1000 if run_bench else 200
|
||||
for config in configs:
|
||||
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
|
||||
matmul_call_str = f"""
|
||||
if '{configStr}' not in failed_configs:
|
||||
for i in range(10):
|
||||
for i in range({runs}):
|
||||
d = matmul_{configStr}(a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1))"""
|
||||
f_kernel[idx % ngpus].write(matmul_call_str + "\n")
|
||||
f_kernel[idx % jobs].write(matmul_call_str + "\n")
|
||||
idx += 1
|
||||
# post string
|
||||
for fi in range(ngpus):
|
||||
for fi in range(jobs):
|
||||
f_kernel[fi].write(" return d\n")
|
||||
|
||||
### def main and call test_gemm
|
||||
# def main and call test_gemm
|
||||
def_main_str = """
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
@@ -292,44 +314,54 @@ def main():
|
||||
args = parser.parse_args()
|
||||
numThreads = args.n
|
||||
"""
|
||||
test_gemm_call_str = f'test_gemm({M}, {N}, {K}, torch.float16, numThreads)'
|
||||
for fi in range(ngpus):
|
||||
f_kernel[fi].write(def_main_str)
|
||||
f_kernel[fi].write(test_gemm_call_str + "\n\n")
|
||||
f_kernel[fi].write("""if __name__ == '__main__':
|
||||
test_gemm_call_str = f'test_gemm({M}, {N}, {K}, numThreads)'
|
||||
for fi in range(jobs):
|
||||
f_kernel[fi].write(def_main_str)
|
||||
f_kernel[fi].write(test_gemm_call_str + "\n\n")
|
||||
f_kernel[fi].write("""if __name__ == '__main__':
|
||||
sys.exit(main())""")
|
||||
f_kernel[fi].close()
|
||||
f_kernel[fi].close()
|
||||
|
||||
def extract_kernel_time(M, N, K, config, gpuid):
|
||||
|
||||
def extract_kernel_time(M, N, K, config, df):
|
||||
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
|
||||
parse_result_cmd = f'sed -n \'/matmul_kernel_{configStr}/p\' results-{gpuid}.csv | awk -F \',\' \'{{print $NF}}\' | tail -n1'
|
||||
parsed_outputs = run_bash_command(parse_result_cmd)
|
||||
return config, parsed_outputs
|
||||
df = df[df['KernelName'].str.contains(configStr)]
|
||||
meanTime = df['DurationNs'].tail(100).mean()
|
||||
return config, meanTime
|
||||
|
||||
|
||||
def profile_batch_kernels(M, N, K, gpuid, verbose):
|
||||
def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose):
|
||||
ngpus = len(gpus)
|
||||
gpuIdx = gpus.index(gpuid)
|
||||
if gpuIdx + 1 > jobs:
|
||||
return
|
||||
os.environ['ROCR_VISIBLE_DEVICES'] = str(gpuid)
|
||||
run_bash_command(f"rocprof --stats -o results-{gpuid}.csv python {generated_kernel_name(M, N, K, gpuid)}", capture=(verbose < 2))
|
||||
jobId = gpuIdx
|
||||
while jobId < jobs:
|
||||
if verbose:
|
||||
print(f"profiling {generated_kernel_name(M, N, K, jobId)} on GPU {gpuid}")
|
||||
run_bash_command_wrapper(f"rocprof --stats -o results-{jobId}.csv python {generated_kernel_name(M, N, K, jobId)}", capture=(verbose < 2))
|
||||
jobId += ngpus
|
||||
|
||||
|
||||
def tune_gemm_config(M, N, K, configs, verbose=0, num_threads=16, gpus = [0]):
|
||||
## Generate kernel out of all configs
|
||||
generate_kernel(M, N, K, configs, gpus)
|
||||
def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs, run_bench, jobs, verbose=0, num_threads=16, gpus=[0]):
|
||||
# Generate kernel out of all configs
|
||||
generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs, jobs, run_bench)
|
||||
|
||||
## remove any compiled kernel in the cache
|
||||
# remove any compiled kernel in the cache
|
||||
run_bash_command("rm -rf ~/.triton/cache")
|
||||
|
||||
## precompile the kernels in parallel
|
||||
# precompile the kernels in parallel
|
||||
start_time = datetime.now()
|
||||
for gpu_id in gpus:
|
||||
run_bash_command(f"python {generated_kernel_name(M, N, K, gpu_id)} -n {num_threads}", capture=(verbose < 2))
|
||||
for i in range(jobs):
|
||||
run_bash_command(f"python {generated_kernel_name(M, N, K, i)} -n {num_threads}", capture=(verbose < 2))
|
||||
compile_end = datetime.now()
|
||||
compile_time = compile_end - start_time
|
||||
if verbose:
|
||||
print(f"compile time: {compile_time}", flush=True)
|
||||
|
||||
## profile generated kernels
|
||||
running = [multiprocessing.Process(target=profile_batch_kernels, args=(M,N,K,gpu_id,verbose)) for gpu_id in gpus]
|
||||
# profile generated kernels
|
||||
running = [multiprocessing.Process(target=profile_batch_kernels, args=(M, N, K, gpu_id, gpus, jobs, verbose)) for gpu_id in gpus]
|
||||
for p in running:
|
||||
p.start()
|
||||
for p in running:
|
||||
@@ -340,24 +372,24 @@ def tune_gemm_config(M, N, K, configs, verbose=0, num_threads=16, gpus = [0]):
|
||||
if verbose:
|
||||
print(f"profile time: {profile_time}", flush=True)
|
||||
|
||||
## post process results.csv to get the best config and minTime
|
||||
## TODO: process the file in parallel
|
||||
# post process results.csv to get the best config and minTime
|
||||
# TODO: process the file in parallel
|
||||
minTime = 1024 * 1024 * 1024
|
||||
thread_pool = multiprocessing.Pool(processes=num_threads)
|
||||
tasks = []
|
||||
idx = 0
|
||||
df_prof = [pd.read_csv(f"results-{i}.csv") for i in range(jobs)]
|
||||
for config in configs:
|
||||
file_idx = idx % len(gpus)
|
||||
gpu_id = gpus[file_idx]
|
||||
tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, gpu_id))]
|
||||
file_idx = idx % jobs
|
||||
tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, df_prof[file_idx]))]
|
||||
idx += 1
|
||||
thread_pool.close()
|
||||
thread_pool.join()
|
||||
|
||||
for task in tasks:
|
||||
config, parsed_outputs = task.get()
|
||||
if parsed_outputs:
|
||||
min_us = int(parsed_outputs[0]) / 1000
|
||||
config, myTime = task.get()
|
||||
if myTime:
|
||||
min_us = myTime / 1000
|
||||
if min_us < minTime:
|
||||
minTime = min_us
|
||||
bestConfig = config
|
||||
@@ -370,12 +402,44 @@ def tune_gemm_config(M, N, K, configs, verbose=0, num_threads=16, gpus = [0]):
|
||||
print(f"post procesing time: {post_time}", flush=True)
|
||||
return minTime, bestConfig, compile_time, profile_time, post_time
|
||||
|
||||
def gen_input(M, N, ty_name, needTrans, seed, device='cuda'):
    """Generate a random (M, N) matmul operand and an fp16 copy of it.

    Args:
        M: Number of rows of the returned tensors.
        N: Number of columns of the returned tensors.
        ty_name: Key into ``name_to_tl_types`` selecting the element type.
        needTrans: When true, the data is generated as (N, M) and transposed,
            so the returned (M, N) tensor is a transposed view.
        seed: Seed applied to the torch CPU and CUDA random generators.
        device: Device on which the random data is created.

    Returns:
        A pair ``(operand, operand_f16)``: the operand in the requested
        element type, and the same values converted to ``torch.float16``.
        When torch has no native dtype for the requested fp8 type, the
        operand is an int8 tensor reinterpreted by triton as that type.
    """
    d_type = name_to_tl_types[ty_name]
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    @triton.jit
    def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        vals = tl.load(input_ptr + offsets, mask=mask)
        tl.store(output_ptr + offsets, vals, mask=mask)

    # Honor the requested device instead of always allocating on 'cuda'.
    if needTrans:
        raw_data = torch.randn((N, M), dtype=torch.float32, device=device).T
    else:
        raw_data = torch.randn((M, N), dtype=torch.float32, device=device)

    torch_supports_type = (
        (d_type == tl.float8e4b8 and TORCH_HAS_FP8E4B8)
        or (d_type == tl.float8e5b16 and TORCH_HAS_FP8E5B16)
        or not d_type.is_fp8()
    )
    if torch_supports_type:
        operand = raw_data.to(tl_to_torch_types[d_type])
        operand_f16 = operand.to(torch.float16)
    else:
        f8_tensor = raw_data.to(torch.int8)
        # keep only two bits of exponent to avoid overflow
        f8_tensor = f8_tensor & 0b00111111
        operand = triton.reinterpret(f8_tensor, d_type)
        operand_f16 = torch.empty_like(f8_tensor, dtype=torch.float16)
        n_elements = raw_data.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
        # The fp16 copy is produced by the triton copy kernel above.
        copy_kernel[grid](operand, operand_f16, n_elements, BLOCK_SIZE=1024)

    return operand, operand_f16
|
||||
|
||||
def matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu):
|
||||
# Check constraints.
|
||||
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
|
||||
assert a.is_contiguous(), "Matrix A must be contiguous"
|
||||
assert b.is_contiguous(), "Matrix B must be contiguous"
|
||||
#assert a.is_contiguous(), "Matrix A must be contiguous"
|
||||
#assert b.is_contiguous(), "Matrix B must be contiguous"
|
||||
M, K = a.shape
|
||||
K, N = b.shape
|
||||
# 1D launch kernel where each block gets its own program.
|
||||
@@ -390,35 +454,39 @@ def matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
BLOCK_SIZE_M = block_m,
|
||||
BLOCK_SIZE_N = block_n,
|
||||
BLOCK_SIZE_K = block_k,
|
||||
GROUP_SIZE_M = group_m,
|
||||
SPLIT_K = split_k,
|
||||
num_warps = num_warps,
|
||||
num_stages = num_stages,
|
||||
waves_per_eu = waves_per_eu,
|
||||
BLOCK_SIZE_M=block_m,
|
||||
BLOCK_SIZE_N=block_n,
|
||||
BLOCK_SIZE_K=block_k,
|
||||
GROUP_SIZE_M=group_m,
|
||||
SPLIT_K=split_k,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
waves_per_eu=waves_per_eu,
|
||||
)
|
||||
return c
|
||||
|
||||
|
||||
def test_correctness(M, N, K, config, verbose, datatype = torch.float16):
|
||||
def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, config, verbose):
|
||||
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu = read_config(config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
a = torch.randn((M, K), device='cuda', dtype=datatype)
|
||||
b = torch.randn((K, N), device='cuda', dtype=datatype)
|
||||
#a = torch.randn((M, K), device='cuda', dtype=datatype)
|
||||
#b = torch.randn((K, N), device='cuda', dtype=datatype)
|
||||
a, a_fp16 = gen_input(M, K, dtype_a, col_a, 1, device='cuda')
|
||||
b, b_fp16 = gen_input(K, N, dtype_b, col_b, 2, device='cuda')
|
||||
# Allocates output.
|
||||
c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
|
||||
c = torch.zeros((M, N), device=a.device, dtype=tl_to_torch_types[name_to_tl_types[dtype_c]])
|
||||
triton_output = matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu)
|
||||
torch_output = torch.matmul(a, b)
|
||||
#print(f"triton_output={triton_output}")
|
||||
#print(f"torch_output={torch_output}")
|
||||
torch_output = torch.matmul(a_fp16, b_fp16)
|
||||
# print(f"triton_output={triton_output}")
|
||||
# print(f"torch_output={torch_output}")
|
||||
rtol = 0 if torch.version.hip is None else 1e-2
|
||||
row_a_str = 'N' if col_a else 'T'
|
||||
row_b_str = 'N' if col_b else 'T'
|
||||
size_str = ''
|
||||
if verbose:
|
||||
size_str = f'SIZE M: {M}, N: {N}, K: {K} '
|
||||
if torch.allclose(triton_output, torch_output, atol=1e-1, rtol=rtol):
|
||||
size_str = f'SIZE M: {M}, N: {N}, K: {K}, trans: {row_a_str}{row_b_str}'
|
||||
if torch.allclose(triton_output.to(torch.float16), torch_output, atol=1e-1, rtol=rtol):
|
||||
print(f'{size_str} Correct✅')
|
||||
else:
|
||||
print(f'{size_str} Incorrect❌')
|
||||
@@ -444,26 +512,94 @@ def parse_args():
|
||||
parser.add_argument("-m", type=int, default=0)
|
||||
parser.add_argument("-n", type=int, default=0)
|
||||
parser.add_argument("-k", type=int, default=0)
|
||||
parser.add_argument("-col_a", action='store_true', default=False, help='whether matrix a is column major')
|
||||
parser.add_argument("-col_b", action='store_true', default=False, help='whether matrix b is column major')
|
||||
parser.add_argument("-dtype_a", type=str, default='fp16', help="matrix a element data type")
|
||||
parser.add_argument("-dtype_b", type=str, default='fp16', help="matrix b element data type")
|
||||
parser.add_argument("-dtype_c", type=str, default='fp16', help="output element data type")
|
||||
parser.add_argument("--ngpus", type=int, default=0, help='number of GPUs used in the profiling step')
|
||||
parser.add_argument("--gpu_ids", type=lambda s: [int(id) for id in s.split(',')], default=[], help='list of gpu ids to use for tuning')
|
||||
parser.add_argument("--gemm_size_file", type=str, default="", help='yaml file to indicate matrix size')
|
||||
parser.add_argument("--tuning_results_file", type=str, default=get_default_tuning_result_filename(), help='yaml file to store tuning results')
|
||||
parser.add_argument("--o", type=str, default=get_default_tuning_result_filename(), help='yaml file to store tuning results')
|
||||
parser.add_argument("--keep", action='store_true', default=False, help='keep generated files')
|
||||
parser.add_argument("--compare", action='store_true', default=False, help="Whether check result correctness")
|
||||
parser.add_argument("--compare_wo_tuning", action='store_true', default=False, help="Whether check result correctness")
|
||||
parser.add_argument("--benchmark", action='store_true', default=False, help="Benchmark the given config")
|
||||
parser.add_argument("--time_breakdown", action='store_true', default=False, help="Show detailed time breakdown of each step during the tuning")
|
||||
parser.add_argument("--verbose", action='store_true', default=False, help="enables time_breakdown and additional logging messages")
|
||||
parser.add_argument("--num_threads", type=int, default=16, help="number of threads to use for kernel compilation and post processing")
|
||||
parser.add_argument("--jobs", type=int, default=16, help="number of generated files")
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
# Whether the installed torch exposes the fnuz fp8 dtypes.
TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz')
TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz')

# Triton element type -> torch dtype. The fp8 entries are added only when
# torch provides the corresponding dtype.
tl_to_torch_types = {
    tl.float16: torch.float16,
    tl.bfloat16: torch.bfloat16,
    tl.float32: torch.float32,
    tl.int8: torch.int8,
    tl.int32: torch.int32,
}
if TORCH_HAS_FP8E5B16:
    tl_to_torch_types.update({tl.float8e5b16: torch.float8_e5m2fnuz})
if TORCH_HAS_FP8E4B8:
    tl_to_torch_types.update({tl.float8e4b8: torch.float8_e4m3fnuz})

# Command-line / yaml type name -> triton element type.
name_to_tl_types = dict(
    int8=tl.int8,
    int32=tl.int32,
    fp16=tl.float16,
    fp32=tl.float32,
    bf16=tl.bfloat16,
    fp8=tl.float8e4b8,
    bf8=tl.float8e5b16,
)
|
||||
|
||||
def process_item(item):
    """Split one yaml entry into the gemm problem and the remaining config.

    The keys 'M', 'N', 'K', 'rowMajorA' and 'rowMajorB' are read and then
    removed from ``item`` in place; whatever is left in the dict is returned
    as the last element.

    Returns:
        ``(M, N, K, col_a, col_b, item)`` where ``col_a`` / ``col_b`` are
        False when the corresponding 'rowMajor*' value is 'T' and True
        otherwise.
    """
    M, N, K = (item[key] for key in ('M', 'N', 'K'))
    col_a = item['rowMajorA'] != 'T'
    col_b = item['rowMajorB'] != 'T'
    # Strip the problem description so only the tuning parameters remain.
    for key in ('M', 'N', 'K', 'rowMajorA', 'rowMajorB'):
        del item[key]
    return M, N, K, col_a, col_b, item
|
||||
|
||||
def type_name_to_bytes(ty_name):
    """Return the element size in bytes implied by a type name.

    The size is taken from the bit width contained in the name, checked in
    the order '32', '16', '8'. If none of them occurs, a message is printed
    and the process exits with status 1.
    """
    for width_tag, n_bytes in (('32', 4), ('16', 2), ('8', 1)):
        if width_tag in ty_name:
            return n_bytes
    print(f"Unrecognized input type name {ty_name}")
    sys.exit(1)
|
||||
|
||||
def format_output(unformatted):
    """Format a number for display, choosing precision by magnitude.

    Values below 0.0001 use scientific notation with three decimals, values
    above 1000 use one decimal, and everything else uses two decimals.
    """
    if unformatted < 0.0001:
        spec = ".3e"
    elif unformatted > 1000:
        spec = ".1f"
    else:
        spec = ".2f"
    return format(unformatted, spec)
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
matrix_size_file = args.gemm_size_file
|
||||
tuning_output_file = args.tuning_results_file
|
||||
tuning_output_file = args.o
|
||||
keepTmp = args.keep
|
||||
run_bench = args.benchmark
|
||||
jobs = args.jobs
|
||||
|
||||
# Get GPU ids
|
||||
ngpus = args.ngpus
|
||||
gpu_ids = args.gpu_ids
|
||||
if ngpus != 0 and gpu_ids:
|
||||
@@ -476,99 +612,122 @@ def main():
|
||||
if gpu_ids:
|
||||
gpus = gpu_ids
|
||||
|
||||
if run_bench:
|
||||
gpus = [gpus[0]]
|
||||
jobs = 1
|
||||
|
||||
# Get element type
|
||||
dtype_a = args.dtype_a
|
||||
dtype_b = args.dtype_b
|
||||
dtype_c = args.dtype_c
|
||||
if not dtype_a in name_to_tl_types or not dtype_b in name_to_tl_types or not dtype_c in name_to_tl_types:
|
||||
print(f"Unsupported dtype_a {args.dtype_a} or dtype_b {args.dtype_b} or dtype_c {args.dtype_c}")
|
||||
print("Supported types: ", list(name_to_tl_types.keys()))
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
mnks = []
|
||||
## TODO: make it more robust to get user input
|
||||
# TODO: make it more robust to get user input
|
||||
if matrix_size_file == "" or not os.path.isfile(matrix_size_file):
|
||||
M = args.m
|
||||
N = args.n
|
||||
K = args.k
|
||||
mnks = [(M, N, K)]
|
||||
col_a = args.col_a
|
||||
col_b = args.col_b
|
||||
mnks = [(M, N, K, col_a, col_b)]
|
||||
else:
|
||||
with open(matrix_size_file) as file:
|
||||
matrix_sizes = yaml.safe_load(file)
|
||||
for sizes in matrix_sizes:
|
||||
M = sizes['M']
|
||||
N = sizes['N']
|
||||
K = sizes['K']
|
||||
mnks.append((M, N, K))
|
||||
|
||||
## Check correctness from given configs
|
||||
if args.compare_wo_tuning:
|
||||
for item in matrix_sizes:
|
||||
M = item['M']
|
||||
N = item['N']
|
||||
K = item['K']
|
||||
del item['M']
|
||||
del item['N']
|
||||
del item['K']
|
||||
test_correctness(M, N, K, item, True)
|
||||
M, N, K, col_a, col_b, item = process_item(item)
|
||||
mnks.append((M, N, K, col_a, col_b, item))
|
||||
|
||||
# Check correctness from given configs
|
||||
if args.compare_wo_tuning:
|
||||
for (M, N, K, col_a, col_b, myConfig) in mnks:
|
||||
test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, item, True)
|
||||
return
|
||||
|
||||
configs_full = get_full_tuning_space()
|
||||
|
||||
start_time = datetime.now()
|
||||
print(f"Tuning starts at: {start_time}", flush=True)
|
||||
if run_bench:
|
||||
print(f"Benchmarking gemm with {dtype_a} inputs")
|
||||
print("trans M N K TFLOPS")
|
||||
else:
|
||||
print(f"Tuning starts at: {start_time}", flush=True)
|
||||
f_results = open(tuning_output_file, 'w')
|
||||
|
||||
f_results = open(tuning_output_file, 'w')
|
||||
for (M, N, K) in mnks:
|
||||
for (M, N, K, col_a, col_b, myConfig) in mnks:
|
||||
start_local_time = datetime.now()
|
||||
## Obtain a pruned tuning space according to gemm size
|
||||
pruned_configs = prune_configs(M, N, K, configs_full)
|
||||
# Obtain a pruned tuning space according to gemm size
|
||||
# If running benchmark, use the provided config
|
||||
pruned_configs = [myConfig] if run_bench else prune_configs(M, N, K, configs_full, type_name_to_bytes(dtype_a), type_name_to_bytes(dtype_b))
|
||||
|
||||
size_str = f'SIZE: {M} {N} {K}'
|
||||
print(f"{size_str} nConfigs: {len(pruned_configs)}", end=" ", flush=True)
|
||||
row_a_str = 'N' if col_a else 'T'
|
||||
row_b_str = 'N' if col_b else 'T'
|
||||
size_str = f'SIZE: {M} {N} {K} {row_a_str}{row_b_str}'
|
||||
if not run_bench:
|
||||
print(f"{size_str} nConfigs: {len(pruned_configs)}", end=" ", flush=True)
|
||||
|
||||
## The main tuning function for one gemm size
|
||||
# The main tuning function for one gemm size
|
||||
verbose_level = 0
|
||||
if args.time_breakdown:
|
||||
verbose_level = 1
|
||||
if args.verbose:
|
||||
verbose_level = 2
|
||||
minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(M, N, K, pruned_configs, num_threads=args.num_threads, gpus = gpus, verbose=verbose_level)
|
||||
minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, pruned_configs, run_bench, jobs, num_threads=args.num_threads, gpus=gpus, verbose=verbose_level)
|
||||
|
||||
## post processing the numbers
|
||||
# post processing the numbers
|
||||
perf_tflops = lambda us: 2 * M * N * K * 1e-12 / (us * 1e-6)
|
||||
tri_tflops = perf_tflops(minTime)
|
||||
if tri_tflops < 0.0001:
|
||||
formatted_tflops = "{:.3e}".format(tri_tflops)
|
||||
else:
|
||||
formatted_tflops = "{:.2f}".format(tri_tflops)
|
||||
print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', end=" ", flush=True)
|
||||
formatted_tflops = format_output(tri_tflops)
|
||||
minTime = format_output(minTime)
|
||||
if not run_bench:
|
||||
print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', end=" ", flush=True)
|
||||
|
||||
bestConfig_compact_str, _ = gen_kernel_and_configStr_from_config(M, N, K, bestConfig)
|
||||
print(f'best_config: {bestConfig_compact_str}', end=" ", flush=True)
|
||||
if not run_bench:
|
||||
print(f'best_config: {bestConfig_compact_str}', end=" ", flush=True)
|
||||
|
||||
## write best config to tuning_results.yaml
|
||||
sizeDict = {'M': M, 'N': N, 'K': K}
|
||||
# write best config to tuning_results.yaml
|
||||
if run_bench:
|
||||
print(f"{row_a_str}{row_b_str} {M:5d} {N:5d} {K:5d} {formatted_tflops}")
|
||||
|
||||
sizeDict = {'M': M, 'N': N, 'K': K, 'rowMajorA': row_a_str, 'rowMajorB': row_b_str}
|
||||
sizeDict.update(bestConfig)
|
||||
f_results.write("- " + str(sizeDict) + " ")
|
||||
f_results.write(f'# TFLOPS: {formatted_tflops} time(us): {minTime:.2f}\n')
|
||||
if not run_bench:
|
||||
f_results.write("- " + str(sizeDict) + " ")
|
||||
f_results.write(f'# TFLOPS: {formatted_tflops} time(us): {minTime}\n')
|
||||
|
||||
## remove generated files if asked to
|
||||
# remove generated files if asked to
|
||||
if not keepTmp:
|
||||
for gpu_id in gpus:
|
||||
generated_script = generated_kernel_name(M, N, K, gpu_id)
|
||||
for i in range(jobs):
|
||||
generated_script = generated_kernel_name(M, N, K, i)
|
||||
os.remove(generated_script)
|
||||
os.remove(generated_script + ".failed_configs")
|
||||
for f in glob.glob(f"results-{gpu_id}.*"):
|
||||
for f in glob.glob(f"results-{i}.*"):
|
||||
os.remove(f)
|
||||
|
||||
## Check correctness if asked to
|
||||
# Check correctness if asked to
|
||||
if args.compare:
|
||||
print("correctness: ", end=" ", flush=True)
|
||||
test_correctness(M, N, K, bestConfig, False)
|
||||
else:
|
||||
test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, bestConfig, False)
|
||||
elif not run_bench:
|
||||
print("", flush=True)
|
||||
|
||||
end_local_time = datetime.now()
|
||||
print(f">>> Elapsed time: {end_local_time - start_local_time} = {compile_time} (compile) + {profile_time} (profile) + {post_time} (post processing)", flush=True)
|
||||
if not run_bench:
|
||||
print(f">>> Elapsed time: {end_local_time - start_local_time} = {compile_time} (compile) + {profile_time} (profile) + {post_time} (post processing)", flush=True)
|
||||
|
||||
f_results.close()
|
||||
if not run_bench:
|
||||
f_results.close()
|
||||
|
||||
end_time = datetime.now()
|
||||
tuning_time = end_time - start_time
|
||||
print(f"Tuning ends at: {end_time}")
|
||||
print(f"Total tuning time (h:m:s): {tuning_time}")
|
||||
if not run_bench:
|
||||
print(f"Tuning ends at: {end_time}")
|
||||
print(f"Total tuning time (h:m:s): {tuning_time}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user