
GEMM tuning script v2

This is v2 of the GEMM tuning script. It is based on @scxiao's v1 (https://github.com/ROCmSoftwarePlatform/triton/pull/309) and @alefimov-amd's thread pool (https://github.com/ROCmSoftwarePlatform/triton/pull/310).

Main features

  • rocprof is used to measure the time for kernels in the full tuning space
  • Each kernel is executed 10 times and the execution time of the last instance is used
  • All kernels are compiled in parallel
  • Two modes for correctness checking
    • During tuning, check correctness with the best perf_config for the current gemm size
    • Without tuning, check correctness based on the tuning results, which include the best perf_config for each gemm size
  • The process takes about 30 - 40 minutes for the full tuning space with ~15000 configs
  • Limitations
    • For now, only fp16 inputs are supported. It should be trivial to extend to other types, but mixed input types may require some work
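
The "last instance" rule above can be illustrated with a minimal sketch. It assumes the profiler output has already been reduced to (kernel name, duration) pairs in launch order; the function names and the sample data are hypothetical and are not taken from tune_gemm.py.

```python
def last_instance_times(rows):
    """Map each kernel name to the duration of its last recorded launch.

    `rows` is an iterable of (kernel_name, duration_ns) pairs in launch order.
    """
    last = {}
    for name, duration_ns in rows:
        last[name] = duration_ns  # later launches overwrite earlier ones
    return last


def pick_best(rows):
    """Return (kernel_name, duration_ns) for the kernel with the fastest last launch."""
    last = last_instance_times(rows)
    return min(last.items(), key=lambda item: item[1])


# Hypothetical measurements: two kernels, two launches each.
rows = [
    ("matmul_kernel_A", 120), ("matmul_kernel_A", 100),
    ("matmul_kernel_B", 90), ("matmul_kernel_B", 95),
]
best_name, best_time = pick_best(rows)
```

Only the final launch of each kernel counts here, so kernel B is picked on its last launch (95), not on its fastest one (90).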

Usage

Go to the script dir

cd triton/scripts/amd/gemm/
  1. Tune gemm sizes given in a yaml file and check correctness on the way
python tune_gemm.py --gemm_size_file input_gemm_sizes.yaml --compare
  2. Tune a single gemm size
python tune_gemm.py -m 16 -n 16 -k 16
  3. Choose the file to store tuning results
python tune_gemm.py --gemm_size_file input_gemm_sizes.yaml --o output_tuning.yaml
  4. Only check correctness given the tuning results
python tune_gemm.py --gemm_size_file output_tuning.yaml --compare_wo_tuning

Note that the tuning results file is provided as the gemm_size_file in this scenario.

Overview of implementations

Workflow of the tuning process

  1. Generate the full tuning space. For now the ranges for each tuning parameter are hard-coded
  2. Prune the tuning space according to the current GEMM size and some rules
    • BLOCK_SIZE must be equal to or larger than the mfma instruction size.
    • SPLIT_K * BLOCK_SIZE_K must divide K. Therefore, we do not need EVEN_K in the kernel.
    • When split-k is not needed, i.e. both M and N are large, SPLIT_K must be 1.
    • GROUP_M * BLOCK_SIZE_M must be smaller than M. Otherwise, GROUP_M must be 1.
    • When BLOCK_SIZE_K = 128, neither BLOCK_SIZE_M nor BLOCK_SIZE_N can be 128. Otherwise too much LDS will be required. This needs further investigation.
    • Skip BLOCK_SIZE_M or BLOCK_SIZE_N values that are more than 2 times larger than M or N, respectively.
  3. Open a file generated_kernel{M}-{N}-{K}-{gpuid}.py and write the following into the file
    1. For each config in the pruned space, generate a kernel with name matmul_kernel_{configStr}, where configStr contains the gemm size and the tuning parameters.
    2. Generate matmul function for each config in a similar way
    3. Generate try_config functions for each matmul function.
    4. Generate test_gemm, which does
      1. Add all try_config functions to the thread pool with thread_pool.apply_async(try_config). This is used to compile all kernels in parallel.
      2. Call each matmul function in a for loop of 10 iterations
    5. Generate main function
  4. Run the generated script with 16 workers. This will compile all kernels in parallel.
  5. Invoke rocprof on the generated script
  6. Post-process results.csv by extracting the execution time of the last instance of each kernel. Pick the best one, write it to file, and return.
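
The pruning rules in step 2 can be sketched as a single predicate. This is an illustration under stated assumptions, not the code in tune_gemm.py: the function name `keep_config`, the default `mfma_size=16`, and the `large_mn` threshold that decides when M and N count as "large" are all hypothetical, and the mfma check is applied to BLOCK_SIZE_M and BLOCK_SIZE_N only.

```python
def keep_config(M, N, K, cfg, mfma_size=16, large_mn=1024):
    """Return True if `cfg` survives the pruning rules for a GEMM of size M x N x K.

    `mfma_size` and `large_mn` are illustrative values, not the script's own.
    """
    bm, bn, bk = cfg["BLOCK_SIZE_M"], cfg["BLOCK_SIZE_N"], cfg["BLOCK_SIZE_K"]
    group_m, split_k = cfg["GROUP_M"], cfg["SPLIT_K"]

    # Block sizes must be at least the mfma instruction size.
    if bm < mfma_size or bn < mfma_size:
        return False
    # SPLIT_K * BLOCK_SIZE_K must divide K, so the kernel needs no EVEN_K handling.
    if K % (split_k * bk) != 0:
        return False
    # Split-k is unnecessary when both M and N are large.
    if M >= large_mn and N >= large_mn and split_k != 1:
        return False
    # GROUP_M * BLOCK_SIZE_M must be smaller than M; otherwise GROUP_M must be 1.
    if group_m * bm >= M and group_m != 1:
        return False
    # BLOCK_SIZE_K = 128 combined with a 128-wide M or N block needs too much LDS.
    if bk == 128 and (bm == 128 or bn == 128):
        return False
    # Skip blocks more than 2 times larger than the problem dimension.
    if bm > 2 * M or bn > 2 * N:
        return False
    return True


def prune(M, N, K, configs):
    """Filter a list of config dicts down to those that pass every rule."""
    return [cfg for cfg in configs if keep_config(M, N, K, cfg)]
```

For example, a config with 64x64x64 blocks, GROUP_M = 1, and SPLIT_K = 2 passes for M = N = K = 256, but is rejected for M = N = K = 4096 under these assumed thresholds because split-k is not needed there.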

GEMM Tuning Script v3

API changes

  • Input and output data types can be provided as -dtype_a, -dtype_b, and -dtype_c. The provided types must be one of ['fp32', 'fp16', 'bf16', 'fp8', 'bf8', 'int8'].
  • Row/column major-ness of operands a and b can be provided as -col_a and -col_b. If set, the corresponding operand is column major. The major-ness is considered part of the problem input, so it should be included in the input yaml file. However, in the yaml file, users should set rowMajorA and rowMajorB as shown in the example below.
  • --benchmark is used to control if the perf config in the input yaml file is used as the tuning space.
  • --jobs is used to control the number of .py files for generated kernels. Note that this can be different from ngpus. This usually means multiple kernel files will be profiled on each GPU. This is necessary to keep each file "small" in terms of execution time.

Implementation changes

  • gen_input is used to generate matmul inputs.
  • Time measurement
    • In benchmark mode, the kernel is executed 1000 times.
    • In tuning mode, each kernel is executed 200 times. We cannot afford longer runs since rocprof hangs if the session takes too long.
    • In both tuning and benchmark mode, kernel time is measured as the average execution time of the last 100 instances.
  • Added error recovery. This helps when rocprof crashes in multi-processing mode.
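
The time measurement rule above amounts to averaging the tail of the recorded durations. A minimal sketch follows, assuming the per-launch durations of one kernel are already available as a list in launch order; the function name and the sample values are hypothetical.

```python
def average_of_last(durations, count=100):
    """Average the last `count` durations (or all of them if fewer were recorded)."""
    if not durations:
        raise ValueError("no durations recorded")
    tail = durations[-count:]
    return sum(tail) / len(tail)


# Hypothetical tuning-mode run: 200 launches, with the first 100 much slower.
durations = [1000.0] * 100 + [10.0] * 100
kernel_time = average_of_last(durations)
```

In this example the slow early launches do not affect the reported time, because only the last 100 launches enter the average.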

Example Usage

Let's say we have an input yaml file, named gemm_input.yaml, that contains the following configs

- {'M': 4864, 'N': 4096, 'K': 8192, 'rowMajorA': 'T', 'rowMajorB': 'N'}
- {'M': 8192, 'N': 8192, 'K': 8192, 'rowMajorA': 'T', 'rowMajorB': 'N'}
  1. Tune with bf8 input types on GPUs 4, 5, 6, and 7, and save the output to output.yaml
python ./tune_gemm.py --gemm_size_file gemm_input.yaml -dtype_a bf8 -dtype_b bf8 --gpu_ids 4,5,6,7 --o output.yaml
  2. Check the correctness of the tuned configs
python ./tune_gemm.py --gemm_size_file output.yaml -dtype_a bf8 -dtype_b bf8 --compare_wo_tuning
  3. Run benchmark of the tuned configs
python ./tune_gemm.py --gemm_size_file output.yaml -dtype_a bf8 -dtype_b bf8 --benchmark

A sample output from benchmark looks like

Benchmarking gemm with bf8 inputs (peak tflops: 1298)
trans    M     N     K    TFLOPS  Efficiency
NT    4864  4096  8192    841.22         65%
NT    8192  8192  8192    745.31         57%
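
The Efficiency column in the sample above is consistent with dividing the measured TFLOPS by the peak TFLOPS and rounding to a whole percent. The sketch below reproduces that arithmetic. It also shows the usual 2 * M * N * K operation count for a GEMM, which is a standard convention and an assumption here, since the README does not state how tune_gemm.py computes TFLOPS. Function names are hypothetical.

```python
def tflops(M, N, K, seconds):
    """TFLOPS for one GEMM, counting one multiply and one add per inner-product term."""
    return 2.0 * M * N * K / seconds / 1e12


def efficiency_percent(measured_tflops, peak_tflops):
    """Measured throughput as a whole-number percentage of the peak."""
    return round(100.0 * measured_tflops / peak_tflops)


# Values taken from the sample benchmark output.
first = efficiency_percent(841.22, 1298)
second = efficiency_percent(745.31, 1298)
```

With the sample values, `first` and `second` evaluate to 65 and 57, matching the two rows of the table.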

One config running script

one_config.py is a script that runs one given matmul config. It is an interface to the tune_gemm.py functionality and can be used for Triton debugging.

Usage

This script supports two methods to specify configuration parameters.

Variant 1: Separate command line attributes.

python one_config.py -m 256 -n 256 -k 256 --block_m 64 --block_n 64 --block_k 64 --group_m 1 --split_k 2 --num_warps 2 --num_stages 0 --waves_per_eu 0

Variant 2: One-line config description. This is how the tune_gemm.py script prints configs.

python one_config.py --config_str M16_N8_K128_BM64_BN64_BK64_GM1_SK2_nW2_nS0_EU0
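
The two variants carry the same fields, so a config string can be unpacked into named parameters. The sketch below is an illustration and not the parser used by one_config.py. The mapping from abbreviations (BM, GM, nW, ...) to the long option names of Variant 1 is inferred from the two examples above and should be treated as an assumption.

```python
import re

# Assumed mapping from config-string prefixes to the Variant 1 option names.
_FIELDS = {
    "M": "m", "N": "n", "K": "k",
    "BM": "block_m", "BN": "block_n", "BK": "block_k",
    "GM": "group_m", "SK": "split_k",
    "nW": "num_warps", "nS": "num_stages", "EU": "waves_per_eu",
}


def parse_config_str(config_str):
    """Split a string like 'M16_N8_..._EU0' into a dict of integer parameters."""
    config = {}
    for token in config_str.split("_"):
        match = re.fullmatch(r"([A-Za-z]+)(\d+)", token)
        if match is None or match.group(1) not in _FIELDS:
            raise ValueError(f"unrecognized token: {token!r}")
        config[_FIELDS[match.group(1)]] = int(match.group(2))
    return config


config = parse_config_str("M16_N8_K128_BM64_BN64_BK64_GM1_SK2_nW2_nS0_EU0")
```

Unknown or malformed tokens raise a ValueError, so a typo in the string fails loudly instead of silently dropping a parameter.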