Fix merge conflicts

This commit is contained in:
Jason Furmanek
2023-09-01 04:01:32 +00:00
parent 3eaeb89d18
commit df5c263a19
28 changed files with 127 additions and 1235 deletions

View File

@@ -1052,126 +1052,108 @@ void init_triton_ir(py::module &&m) {
.def("create_shl",
[](TritonOpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {
<<<<<<< HEAD
auto loc = self.getUnknownLoc();
#ifdef USE_ROCM
mlir::Type elementType = getElementTypeOrSelf(lhs.getType());
unsigned typeWidth = elementType.getIntOrFloatBitWidth();
auto constValue = self.create<mlir::arith::ConstantIntOp>(
loc, typeWidth, elementType);
typeWidth, elementType);
auto zeroConst =
self.create<mlir::arith::ConstantIntOp>(loc, 0, elementType);
self.create<mlir::arith::ConstantIntOp>(0, elementType);
if (lhs.getType().isIntOrIndex()) {
auto cmpValue = self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ult, rhs, constValue);
auto shiftValue = self.create<mlir::arith::ShLIOp>(loc, lhs, rhs);
mlir::arith::CmpIPredicate::ult, rhs, constValue);
auto shiftValue = self.create<mlir::arith::ShLIOp>(lhs, rhs);
return mlir::Value(self.create<mlir::arith::SelectOp>(
loc, cmpValue, shiftValue, zeroConst));
cmpValue, shiftValue, zeroConst));
} else {
auto splatValue = self.create<mlir::tensor::SplatOp>(
loc, lhs.getType(), constValue);
lhs.getType(), constValue);
auto zeroValue = self.create<mlir::tensor::SplatOp>(
loc, lhs.getType(), zeroConst);
lhs.getType(), zeroConst);
auto cmpValue = self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ult, rhs, splatValue);
auto shiftValue = self.create<mlir::arith::ShLIOp>(loc, lhs, rhs);
mlir::arith::CmpIPredicate::ult, rhs, splatValue);
auto shiftValue = self.create<mlir::arith::ShLIOp>(lhs, rhs);
return mlir::Value(self.create<mlir::arith::SelectOp>(
loc, cmpValue, shiftValue, zeroValue));
cmpValue, shiftValue, zeroValue));
}
#else
return mlir::Value(
self.create<mlir::arith::ShLIOp>(loc, lhs, rhs));
#endif
=======
return mlir::Value(self.create<mlir::arith::ShLIOp>(lhs, rhs));
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
#endif
})
.def("create_lshr",
[](TritonOpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {
<<<<<<< HEAD
auto loc = self.getUnknownLoc();
#ifdef USE_ROCM
mlir::Type elementType = getElementTypeOrSelf(lhs.getType());
unsigned typeWidth = elementType.getIntOrFloatBitWidth();
auto constValue = self.create<mlir::arith::ConstantIntOp>(
loc, typeWidth, elementType);
typeWidth, elementType);
auto zeroConst =
self.create<mlir::arith::ConstantIntOp>(loc, 0, elementType);
self.create<mlir::arith::ConstantIntOp>(0, elementType);
if (lhs.getType().isIntOrIndex()) {
auto cmpValue = self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ult, rhs, constValue);
auto shiftValue = self.create<mlir::arith::ShRUIOp>(loc, lhs, rhs);
mlir::arith::CmpIPredicate::ult, rhs, constValue);
auto shiftValue = self.create<mlir::arith::ShRUIOp>(lhs, rhs);
return mlir::Value(self.create<mlir::arith::SelectOp>(
loc, cmpValue, shiftValue, zeroConst));
cmpValue, shiftValue, zeroConst));
} else {
auto splatValue = self.create<mlir::tensor::SplatOp>(
loc, lhs.getType(), constValue);
lhs.getType(), constValue);
auto zeroValue = self.create<mlir::tensor::SplatOp>(
loc, lhs.getType(), zeroConst);
lhs.getType(), zeroConst);
auto cmpValue = self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ult, rhs, splatValue);
auto shiftValue = self.create<mlir::arith::ShRUIOp>(loc, lhs, rhs);
mlir::arith::CmpIPredicate::ult, rhs, splatValue);
auto shiftValue = self.create<mlir::arith::ShRUIOp>(lhs, rhs);
return mlir::Value(self.create<mlir::arith::SelectOp>(
loc, cmpValue, shiftValue, zeroValue));
cmpValue, shiftValue, zeroValue));
}
#else
return mlir::Value(
self.create<mlir::arith::ShRUIOp>(loc, lhs, rhs));
#endif
=======
return mlir::Value(self.create<mlir::arith::ShRUIOp>(lhs, rhs));
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
#endif
})
.def("create_ashr",
[](TritonOpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {
<<<<<<< HEAD
auto loc = self.getUnknownLoc();
#ifdef USE_ROCM
mlir::Type elementType = getElementTypeOrSelf(lhs.getType());
unsigned typeWidth = elementType.getIntOrFloatBitWidth();
auto constValue = self.create<mlir::arith::ConstantIntOp>(
loc, typeWidth, elementType);
typeWidth, elementType);
auto zeroConst =
self.create<mlir::arith::ConstantIntOp>(loc, 0, elementType);
self.create<mlir::arith::ConstantIntOp>(0, elementType);
uint64_t ones_val = 0xFFFFFFFFFFFFFFFF;
auto onesConst =
self.create<mlir::arith::ConstantIntOp>(loc, ones_val, elementType);
self.create<mlir::arith::ConstantIntOp>(ones_val, elementType);
if (lhs.getType().isIntOrIndex()) {
auto negativeCmpValue = self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::slt, lhs, zeroConst);
mlir::arith::CmpIPredicate::slt, lhs, zeroConst);
auto otherValue = mlir::Value(self.create<mlir::arith::SelectOp>(
loc, negativeCmpValue, onesConst, zeroConst));
negativeCmpValue, onesConst, zeroConst));
auto cmpValue = self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ult, rhs, constValue);
auto shiftValue = self.create<mlir::arith::ShRSIOp>(loc, lhs, rhs);
mlir::arith::CmpIPredicate::ult, rhs, constValue);
auto shiftValue = self.create<mlir::arith::ShRSIOp>(lhs, rhs);
return mlir::Value(self.create<mlir::arith::SelectOp>(
loc, cmpValue, shiftValue, otherValue));
cmpValue, shiftValue, otherValue));
} else {
auto splatValue = self.create<mlir::tensor::SplatOp>(
loc, lhs.getType(), constValue);
lhs.getType(), constValue);
auto zeroValue = self.create<mlir::tensor::SplatOp>(
loc, lhs.getType(), zeroConst);
lhs.getType(), zeroConst);
auto onesValue = self.create<mlir::tensor::SplatOp>(
loc, lhs.getType(), onesConst);
lhs.getType(), onesConst);
auto negativeCmpValue = self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::slt, lhs, zeroValue);
mlir::arith::CmpIPredicate::slt, lhs, zeroValue);
auto otherValue = mlir::Value(self.create<mlir::arith::SelectOp>(
loc, negativeCmpValue, onesValue, zeroValue));
negativeCmpValue, onesValue, zeroValue));
auto cmpValue = self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ult, rhs, splatValue);
auto shiftValue = self.create<mlir::arith::ShRSIOp>(loc, lhs, rhs);
mlir::arith::CmpIPredicate::ult, rhs, splatValue);
auto shiftValue = self.create<mlir::arith::ShRSIOp>(lhs, rhs);
return mlir::Value(self.create<mlir::arith::SelectOp>(
loc, cmpValue, shiftValue, otherValue));
cmpValue, shiftValue, otherValue));
}
#else
return mlir::Value(
self.create<mlir::arith::ShRSIOp>(loc, lhs, rhs));
#endif
=======
return mlir::Value(self.create<mlir::arith::ShRSIOp>(lhs, rhs));
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
#endif
})
// AddPtr (similar to GEP)
.def("create_addptr",

View File

@@ -14,14 +14,9 @@ from typing import Any, Tuple
from .._C.libtriton.triton import (add_external_libs, compile_ptx_to_cubin,
get_shared_memory_size, ir,
translate_llvmir_to_hsaco, translate_llvmir_to_ptx,
<<<<<<< HEAD
translate_triton_gpu_to_llvmir, get_arch_info,
get_warp_size)
from ..common.backend import get_backend
=======
translate_triton_gpu_to_llvmir)
from ..common.backend import get_backend, path_to_ptxas
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
# from ..runtime import driver, jit, JITFunction
# TODO: runtime.errors
from ..runtime.autotuner import OutOfResources

View File

@@ -294,15 +294,12 @@ class _attention(torch.autograd.Function):
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
raise RuntimeError("Flash attention currently only supported for compute capability >= 80")
<<<<<<< HEAD
if torch.version.hip is not None:
BLOCK = 64
BLOCK_M = 64
BLOCK_N = 64
else:
BLOCK = 128
=======
BLOCK_M = 128
BLOCK_N = 64
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
BLOCK_M = 128
BLOCK_N = 64
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv

View File

@@ -25,11 +25,7 @@ class OutOfResources(Exception):
class Autotuner(KernelInterface):
<<<<<<< HEAD
def __init__(self, fn, arg_names, configs, key, verbose, reset_to_zero, prune_configs_by: Dict = None):
=======
def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, warmup=25, rep=100):
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
def __init__(self, fn, arg_names, configs, key, verbose, reset_to_zero, prune_configs_by: Dict = None, warmup=25, rep=100):
'''
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
@@ -62,12 +58,9 @@ class Autotuner(KernelInterface):
self.perf_model, self.configs_top_k = perf_model, top_k
self.early_config_prune = early_config_prune
self.fn = fn
<<<<<<< HEAD
self.verbose = verbose
=======
self.warmup = warmup
self.rep = rep
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
self.verbose = verbose
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
@@ -187,11 +180,7 @@ class Config:
return ', '.join(res)
<<<<<<< HEAD
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=False):
=======
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25, rep=100):
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=False, warmup=25, rep=100):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
@@ -222,21 +211,15 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25,
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
<<<<<<< HEAD
:param verbose: a boolean that controls whether the best_config for each key is printed
:type verbose: bool
"""
def decorator(fn):
return Autotuner(fn, fn.arg_names, configs, key, verbose, reset_to_zero, prune_configs_by)
=======
:param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25.
:type warmup: int
:param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
:type rep: int
:param verbose: a boolean that controls whether the best_config for each key is printed
:type verbose: bool
"""
def decorator(fn):
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, warmup, rep)
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
return Autotuner(fn, fn.arg_names, configs, key, verbose, reset_to_zero, prune_configs_by, warmup, rep)
return decorator

View File

@@ -325,12 +325,8 @@ class JITFunction(KernelInterface[T]):
args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
src = f"""
<<<<<<< HEAD
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
=======
def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
from ..compiler import compile, CompiledKernel
sig_key = {sig_keys},
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}

View File

@@ -1,125 +0,0 @@
import argparse
import sys
from .._C.libtriton.triton import ir
# import triton.compiler.compiler as tc
from ..compiler.compiler import (get_amdgpu_arch_fulldetails, llir_to_amdgcn_and_hsaco,
llir_to_ptx, optimize_ttgir, optimize_ttir,
ttgir_to_llir, ttir_to_ttgir, CUDA_DEFAULT_WARP_SIZE)
if __name__ == '__main__':
# valid source and target formats
VALID_FORMATS = ['triton-ir', 'triton-gpu-ir', 'llvm-ir', 'ptx', 'amdgcn']
# set up the argument parser
# TODO: conditional requirements
parser = argparse.ArgumentParser()
parser.add_argument('src', help="Source file to compile")
parser.add_argument('--target', required=True,
help="Target format, one of: " + ', '.join(VALID_FORMATS))
parser.add_argument('--sm', type=int, help="Compute capability to compile for")
parser.add_argument('--ptx-version', type=int, help="PTX version to compile for")
parser.add_argument('--gfx', type=str, help="AMDGPU target to compile for")
parser.add_argument('--triple', type=str, help="target triple, for example: amdgcn-amd-amdhsa")
parser.add_argument('--features', type=str, help="target features, for example: +sramecc,-xnack")
parser.add_argument('--num_warps', type=int, help="number of warps to compile ttgir for")
# parse the args
args = parser.parse_args()
# TODO: clean-up and re-use triton.compiler primitive functions
# check for validity of format arguments
if args.target not in VALID_FORMATS:
print("Invalid target format: " + args.target)
sys.exit(0)
# parse source file to MLIR module
context = ir.context()
module = ir.parse_mlir_module(args.src, context)
module.context = context
# optimizer triton-ir
module = optimize_ttir(module, arch=args.sm)
if args.target == 'triton-ir':
print(module.str())
sys.exit(0)
if not args.num_warps:
args.num_warps = 4
# llvm-ir -> amdgcn
if args.target == 'amdgcn':
# auto detect available architecture and features
# if nothing detected, set with default values
arch_details = get_amdgpu_arch_fulldetails()
if not arch_details:
arch_name = ""
arch_triple = "amdgcn-amd-amdhsa"
arch_features = ""
arch_warpsize = 64
else:
arch_triple, arch_name, arch_features, arch_warpsize = arch_details
# stop processing if architecture name is not automatically detected and is not set manually
if not args.gfx and not arch_name:
raise argparse.ArgumentError(None, "Must specify --gfx for AMDGCN compilation")
# rewrite default and automatically detected values with manually provided data
if args.gfx:
arch_name = args.gfx
if args.triple:
arch_triple = args.triple
if args.features:
arch_features = args.features
# triton-ir -> triton-gpu-ir
# use compute_capability == 80
module = ttir_to_ttgir(module, num_warps=args.num_warps, warpsize=arch_warpsize) # num_stages=3, compute_capability=80)
module = optimize_ttgir(module, num_stages=3, arch=args.gfx)
# triton-gpu-ir -> llvm-ir
# use compute_capability == 80
module = ttgir_to_llir(module, extern_libs=None, arch=args.gfx)
# llvm-ir -> amdgcn asm, hsaco binary
module, hsaco_path = llir_to_amdgcn_and_hsaco(module, arch_name, arch_triple, arch_features)
print(hsaco_path)
print(module)
sys.exit(0)
# set arch depending on platform
if args.gfx:
arch = args.gfx
elif args.sm:
arch = args.sm
else:
raise argparse.ArgumentError(None, "Must specify --sm or --gfx for ttgir compilation")
# triton-ir -> triton-gpu-ir
module = ttir_to_ttgir(module, num_warps=args.num_warps, warpsize=CUDA_DEFAULT_WARP_SIZE)
module = optimize_ttgir(module, num_stages=3, arch=arch)
if args.target == 'triton-gpu-ir':
print(module.str())
sys.exit(0)
# triton-gpu-ir -> llvm-ir
module = ttgir_to_llir(module, extern_libs=None, arch=arch)
if args.target == 'llvm-ir':
print(module)
sys.exit(0)
# llvm-ir -> ptx
if args.target == 'ptx':
if not args.sm:
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
if not args.ptx_version:
raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
module = llir_to_ptx(module, arch=args.sm, ptx_version=args.ptx_version)
# llvm-ir -> amdgcn
if args.target == 'amdgcn':
if not args.gfx:
raise argparse.ArgumentError(None, "Must specify --gfx for AMDGCN compilation")
module, hsaco_path = llir_to_amdgcn_and_hsaco(module, args.gfx)
print(module)

View File

@@ -39,16 +39,10 @@ def _fwd_kernel(
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
<<<<<<< HEAD
q_offset = off_hz * stride_qh
kv_offset = off_hz * stride_kh
Q_block_ptr = tl.make_block_ptr(
base=Q + q_offset,
=======
qvk_offset = off_hz * stride_qh
Q_block_ptr = tl.make_block_ptr(
base=Q + qvk_offset,
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
@@ -56,26 +50,16 @@ def _fwd_kernel(
order=(1, 0)
)
K_block_ptr = tl.make_block_ptr(
<<<<<<< HEAD
base=K + kv_offset,
shape=(BLOCK_DMODEL, N_CTX + P_SEQ),
=======
base=K + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
<<<<<<< HEAD
base=V + kv_offset,
shape=(N_CTX + P_SEQ, BLOCK_DMODEL),
=======
base=V + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
@@ -97,11 +81,7 @@ def _fwd_kernel(
q = (q * qk_scale).to(tl.float16)
# loop over k, v and update accumulator
lo = 0
<<<<<<< HEAD
hi = P_SEQ + (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + P_SEQ
=======
hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
for start_n in range(lo, hi, BLOCK_N):
# -- load k, v --
k = tl.load(K_block_ptr)
@@ -109,11 +89,7 @@ def _fwd_kernel(
# -- compute qk ---
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
if IS_CAUSAL:
<<<<<<< HEAD
qk = tl.where(P_SEQ + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
=======
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
qk += tl.dot(q, k)
# -- compute scaling constant ---
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
@@ -135,11 +111,7 @@ def _fwd_kernel(
tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
O_block_ptr = tl.make_block_ptr(
<<<<<<< HEAD
base=Out + q_offset,
=======
base=Out + qvk_offset,
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
@@ -152,11 +124,7 @@ def _fwd_kernel(
@triton.jit
def _bwd_preprocess(
Out, DO,
<<<<<<< HEAD
NewDO, Delta,
=======
Delta,
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
@@ -164,12 +132,10 @@ def _bwd_preprocess(
# load
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
<<<<<<< HEAD
=======
# compute
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
tl.store(Delta + off_m, delta)
@@ -233,22 +199,13 @@ def _bwd_kernel(
# load q, k, v, do on-chip
q = tl.load(q_ptrs)
# recompute p = softmax(qk, dim=-1).T
<<<<<<< HEAD
# NOTE: `do` is pre-divided by `l`; no normalization here
qk = tl.dot(q, tl.trans(k))
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
if CAUSAL:
qk = tl.dot(q, tl.trans(k), out_dtype=tl.float32)
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
else:
qk = tl.dot(q, tl.trans(k), out_dtype=tl.float32)
l_i = tl.load(l_ptrs + offs_m_curr)
p = tl.math.exp2(qk * qk_scale - l_i[:, None])
=======
if CAUSAL:
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
else:
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
qk *= qk_scale
l_i = tl.load(l_ptrs + offs_m_curr)
p = tl.math.exp2(qk - l_i[:, None])
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
@@ -492,18 +449,13 @@ empty = torch.empty(128, device="cuda")
class _attention(torch.autograd.Function):
@staticmethod
<<<<<<< HEAD
def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
=======
def forward(ctx, q, k, v, causal, sm_scale):
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
BLOCK_M = 128
<<<<<<< HEAD
if torch.version.hip is None:
BLOCK_N = 64 if Lk <= 64 else 32
num_stages = 4 if Lk <= 64 else 3
@@ -514,11 +466,6 @@ class _attention(torch.autograd.Function):
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2]
=======
BLOCK_N = 64
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
num_warps = 4 if Lk <= 64 else 8
_fwd_kernel[grid](
@@ -529,48 +476,36 @@ class _attention(torch.autograd.Function):
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
<<<<<<< HEAD
q.shape[0], q.shape[1], q.shape[2], P_SEQ,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
IS_CAUSAL=causal,
num_warps=num_warps,
num_stages=num_stages)
=======
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
IS_CAUSAL=causal,
num_warps=num_warps,
num_stages=4)
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
ctx.save_for_backward(q, k, v, o, L)
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = Lk
ctx.causal = causal
<<<<<<< HEAD
ctx.split_kernel = split_kernel
ctx.P_SEQ = P_SEQ
=======
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
return o
@staticmethod
def backward(ctx, do):
<<<<<<< HEAD
BLOCK = 64
q, k, v, o, l = ctx.saved_tensors
=======
BLOCK = 128
# configuration is not supported
assert(not (ctx.split_kernel and not ctx.causal))
if torch.version.hip is not None:
BLOCK = 64
else:
BLOCK = 128
q, k, v, o, L = ctx.saved_tensors
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
do = do.contiguous()
dq = torch.zeros_like(q)
dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
<<<<<<< HEAD
delta = torch.empty_like(L)
do_scaled = torch.empty_like(do)
delta = torch.empty_like(l)
# Figure out what BLOCK size fwd used and adjust num_blocks accordingly.
# If the two are the same, we don't need this but the bwd pass block size
# is smaller than the fwd so we need this scaling to ensure we loop over all
@@ -588,8 +523,7 @@ class _attention(torch.autograd.Function):
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
l,
delta,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
@@ -597,15 +531,16 @@ class _attention(torch.autograd.Function):
block_scale * ctx.grid[0],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
CAUSAL=ctx.causal,
num_stages=1,
)
else :
dq = torch.zeros_like(q)
_bwd_kernel_dk_dv[(block_scale * ctx.grid[0], ctx.grid[1])](
q, k, v, ctx.sm_scale,
o, do_scaled,
dk, dv,
l,
delta,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
@@ -618,8 +553,7 @@ class _attention(torch.autograd.Function):
q, k, v, ctx.sm_scale,
o, do_scaled,
dq,
l,
delta,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
@@ -630,36 +564,10 @@ class _attention(torch.autograd.Function):
)
# print(h.asm["ttgir"])
return dq, dk, dv, None, None, None
=======
delta = torch.empty_like(L)
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do,
delta,
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do,
dq, dk, dv,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
ctx.grid[0],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
CAUSAL=ctx.causal,
num_stages=1,
)
return dq, dk, dv, None, None
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
attention = _attention.apply
<<<<<<< HEAD
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, P_SEQ',
[(4, 48, 1024, 64, 128),
(4, 48, 2048, 64, 128),
@@ -702,16 +610,6 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, P_SEQ, dtype=torch.float16):
v = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
sm_scale = q.shape[-1] ** (-0.5)
split_kernel = True
=======
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 9, 1024, 64)])
@pytest.mark.parametrize('causal', [False, True])
def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
sm_scale = 0.5
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX + P_SEQ), device="cuda"), diagonal=P_SEQ)
@@ -724,13 +622,8 @@ def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
<<<<<<< HEAD
# # triton implementation
tri_out = attention(q, k, v, causal, sm_scale, split_kernel)
=======
# triton implementation
tri_out = attention(q, k, v, causal, sm_scale).half()
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
@@ -771,11 +664,7 @@ configs = [triton.testing.Benchmark(
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode, 'causal': causal}
<<<<<<< HEAD
) for mode in ['fwd', 'bwd'] for causal in [True, False]]
=======
) for mode in ['fwd', 'bwd'] for causal in [False, True]]
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
@triton.testing.perf_report(configs)
@@ -793,11 +682,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
sm_scale = 1.3
<<<<<<< HEAD
fn = lambda: attention(q, k, v, causal, sm_scale, split_kernel)
=======
fn = lambda: attention(q, k, v, causal, sm_scale)
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)