mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Benchmark FA on 2 GCDs (#393)
This commit is contained in:
71
scripts/amd/benchmark_flash_attention.py
Normal file
71
scripts/amd/benchmark_flash_attention.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import argparse
|
||||
import sys
|
||||
import git
|
||||
|
||||
|
||||
# Resolve the top-level directory of the enclosing git checkout so that the
# perf kernels can be found no matter which sub-directory the script is
# launched from.
git_repo = git.Repo('.', search_parent_directories=True)
git_root = git_repo.git.rev_parse("--show-toplevel")

# Give the in-tree kernels priority over anything else on the import path.
perf_kernels_dir = git_root + '/python/perf-kernels'
sys.path.insert(0, perf_kernels_dir)

# The kernel file name starts with a digit and contains dashes, so it is not
# a valid identifier for a plain `import` statement; load it by string name.
FA = __import__('06-fused-attention-fwd-transV')

# Callable entry point of the fused attention autograd function.
attention = FA._attention.apply
|
||||
|
||||
import torch
|
||||
|
||||
def benchmark_FA(BATCH, H, N_CTX, D_HEAD, causal, rep, mode, dtype=torch.float16, device="cuda"):
    """Run the fused attention kernel ``rep`` times for external timing.

    The function does not measure anything itself; it issues one warm-up
    launch followed by ``rep`` launches and then blocks until the device has
    finished, so that the caller can time the whole process.

    Args:
        BATCH: Batch size (first dimension of q, k and v).
        H: Number of attention heads (second dimension of q, k and v).
        N_CTX: Sequence length.
        D_HEAD: Head dimension.
        causal: Accepted for interface compatibility. It is NOT forwarded to
            the kernel: ``attention`` is called with ``(q, k, v, sm_scale)``
            only.
        rep: Number of launches issued after the warm-up launch.
        mode: ``"fwd"`` runs the forward pass only; ``"bwd"`` runs a forward
            pass followed by a backward pass on every repetition. Any other
            value performs just the single warm-up forward pass.
        dtype: Element type of the generated tensors.
        device: Device the tensors are allocated on and that is synchronized
            before returning.
    """
    # q and k are (BATCH, H, N_CTX, D_HEAD); v is allocated with its last two
    # dimensions swapped, i.e. (BATCH, H, D_HEAD, N_CTX).
    q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=device, requires_grad=True)
    k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=device, requires_grad=True)
    v = torch.randn((BATCH, H, D_HEAD, N_CTX), dtype=dtype, device=device, requires_grad=True)
    sm_scale = 1.3

    def fn():
        return attention(q, k, v, sm_scale)

    run_backward = mode == "bwd"

    # Warm-up launch (also provides the output shape for the upstream grad).
    o = fn()
    if run_backward:
        do = torch.randn_like(o)
        o.backward(do, retain_graph=True)

    for _ in range(rep):
        if run_backward:
            # A fresh forward pass is required to build a new autograd graph
            # for every backward pass.
            o = fn()
            o.backward(do, retain_graph=True)
        elif mode == "fwd":
            fn()

    # Kernel launches are asynchronous; wait for the device the tensors live
    # on so that the process does not exit before the work has completed.
    torch.cuda.synchronize(device)
|
||||
|
||||
|
||||
def main(args=None):
    """Parse the command line and run one flash attention benchmark.

    Args:
        args: Argument list to parse. Defaults to ``sys.argv[1:]`` when
            ``None``.

    Returns:
        None, so that ``sys.exit(main())`` exits with status 0 on success.

    Raises:
        SystemExit: With status 2 (raised by argparse) when an option is
            missing, malformed, or ``-mode`` is not ``fwd`` or ``bwd``.
    """
    if args is None:
        args = sys.argv[1:]

    parser = argparse.ArgumentParser(
        prog="FA benchmarking",
        description="benchmark FA fwd and bwd with 2 GPUs",
        allow_abbrev=False,
    )

    # None of the options has a sensible default, so all of them are
    # mandatory: a missing option yields a usage error instead of an
    # AttributeError when the parsed namespace is read below.
    for flag in ("-bs", "-nheads", "-d", "-seqlen", "-rep"):
        parser.add_argument(flag, type=int, required=True)
    # Restrict the mode to the two values the benchmark understands; any
    # other value would silently benchmark nothing.
    parser.add_argument("-mode", type=str, required=True, choices=("fwd", "bwd"))

    parsed_args = parser.parse_args(args)

    # benchmark_FA takes (BATCH, H, N_CTX, D_HEAD, causal, rep, mode):
    # note that seqlen precedes the head dimension.
    benchmark_FA(
        parsed_args.bs,
        parsed_args.nheads,
        parsed_args.seqlen,
        parsed_args.d,
        False,
        parsed_args.rep,
        parsed_args.mode,
    )
|
||||
|
||||
|
||||
# Script entry point: propagate main()'s return value as the exit status.
if __name__ == '__main__':
    raise SystemExit(main())
|
||||
70
scripts/amd/run_2gcd.sh
Executable file
70
scripts/amd/run_2gcd.sh
Executable file
@@ -0,0 +1,70 @@
|
||||
#! /bin/bash

## A simple script to run two flash attention kernels
## with batch 2 and 48 heads on two GPUs in parallel,
## sweeping the head dimension (128, 64) and the
## sequence length (1024 .. 16384).
## $1: mode, fwd or bwd

if [[ $# -eq 0 ]];then
    echo "Must specify mode, fwd or bwd" >&2
    exit 1
fi

mode=$1

## Reject unknown modes up front: the benchmark driver would do no work
## for them and the reported tflops would be meaningless.
if [[ $mode != "fwd" && $mode != "bwd" ]];then
    echo "Must specify mode, fwd or bwd" >&2
    exit 1
fi

TRITON_DIR=$(git rev-parse --show-toplevel)

BENCHMARK_DRIVER=${TRITON_DIR}/scripts/amd/benchmark_flash_attention.py

bs=2
nheads=48

## Number of repetitions per sequence length
declare -A repA

if [[ $mode == "fwd" ]];then
    repA[1024]=160000
    repA[2048]=80000
    repA[4096]=40000
    repA[8192]=20000
    repA[16384]=10000
else
    repA[1024]=10000
    repA[2048]=10000
    repA[4096]=2500
    repA[8192]=600
    repA[16384]=100
fi

for d in 128 64
do
    echo "Benchmarking FA $mode kernel with D = $d on 2 GCDs"
    for seqlen in 1024 2048 4096 8192 16384
    do
        rep=${repA[$seqlen]}
        args=(-bs "$bs" -nheads "$nheads" -d "$d" -seqlen "$seqlen" -mode "$mode")

        ## pre-compile the kernel
        python "${BENCHMARK_DRIVER}" "${args[@]}" -rep 1

        start_time=$(date +%s.%3N)
        ## The device selection is set per command so that it does not
        ## leak into the next pre-compile run or the caller's environment.
        ROCR_VISIBLE_DEVICES=0 python "${BENCHMARK_DRIVER}" "${args[@]}" -rep "$rep" &
        ROCR_VISIBLE_DEVICES=1 python "${BENCHMARK_DRIVER}" "${args[@]}" -rep "$rep"

        ## wait for the background run on device 0
        wait
        end_time=$(date +%s.%3N)

        # elapsed time with millisecond resolution
        # keep three digits after floating point.
        elapsed=$(echo "scale=3; $end_time - $start_time" | bc)
        # Convert the elapsed seconds to tflops
        if [[ $mode == "fwd" ]];then
            tflops=$(echo "scale=2; 8*$seqlen*$seqlen*$bs*$nheads*$d*$rep/$elapsed/1000000000000" | bc)
        else
            tflops=$(echo "scale=2; 7*4*0.5*$seqlen*$seqlen*$bs*$nheads*$d*$rep/$elapsed/1000000000000" | bc)
        fi
        echo "$seqlen $tflops tflops $elapsed s"

    done
done
|
||||
Reference in New Issue
Block a user