mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Fix merge conflicts
This commit is contained in:
@@ -1052,126 +1052,108 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("create_shl",
|
||||
[](TritonOpBuilder &self, mlir::Value &lhs,
|
||||
mlir::Value &rhs) -> mlir::Value {
|
||||
<<<<<<< HEAD
|
||||
auto loc = self.getUnknownLoc();
|
||||
#ifdef USE_ROCM
|
||||
mlir::Type elementType = getElementTypeOrSelf(lhs.getType());
|
||||
unsigned typeWidth = elementType.getIntOrFloatBitWidth();
|
||||
auto constValue = self.create<mlir::arith::ConstantIntOp>(
|
||||
loc, typeWidth, elementType);
|
||||
typeWidth, elementType);
|
||||
auto zeroConst =
|
||||
self.create<mlir::arith::ConstantIntOp>(loc, 0, elementType);
|
||||
self.create<mlir::arith::ConstantIntOp>(0, elementType);
|
||||
if (lhs.getType().isIntOrIndex()) {
|
||||
auto cmpValue = self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::ult, rhs, constValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShLIOp>(loc, lhs, rhs);
|
||||
mlir::arith::CmpIPredicate::ult, rhs, constValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShLIOp>(lhs, rhs);
|
||||
return mlir::Value(self.create<mlir::arith::SelectOp>(
|
||||
loc, cmpValue, shiftValue, zeroConst));
|
||||
cmpValue, shiftValue, zeroConst));
|
||||
} else {
|
||||
auto splatValue = self.create<mlir::tensor::SplatOp>(
|
||||
loc, lhs.getType(), constValue);
|
||||
lhs.getType(), constValue);
|
||||
auto zeroValue = self.create<mlir::tensor::SplatOp>(
|
||||
loc, lhs.getType(), zeroConst);
|
||||
lhs.getType(), zeroConst);
|
||||
auto cmpValue = self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::ult, rhs, splatValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShLIOp>(loc, lhs, rhs);
|
||||
mlir::arith::CmpIPredicate::ult, rhs, splatValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShLIOp>(lhs, rhs);
|
||||
return mlir::Value(self.create<mlir::arith::SelectOp>(
|
||||
loc, cmpValue, shiftValue, zeroValue));
|
||||
cmpValue, shiftValue, zeroValue));
|
||||
}
|
||||
#else
|
||||
return mlir::Value(
|
||||
self.create<mlir::arith::ShLIOp>(loc, lhs, rhs));
|
||||
#endif
|
||||
=======
|
||||
return mlir::Value(self.create<mlir::arith::ShLIOp>(lhs, rhs));
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
#endif
|
||||
})
|
||||
.def("create_lshr",
|
||||
[](TritonOpBuilder &self, mlir::Value &lhs,
|
||||
mlir::Value &rhs) -> mlir::Value {
|
||||
<<<<<<< HEAD
|
||||
auto loc = self.getUnknownLoc();
|
||||
#ifdef USE_ROCM
|
||||
mlir::Type elementType = getElementTypeOrSelf(lhs.getType());
|
||||
unsigned typeWidth = elementType.getIntOrFloatBitWidth();
|
||||
auto constValue = self.create<mlir::arith::ConstantIntOp>(
|
||||
loc, typeWidth, elementType);
|
||||
typeWidth, elementType);
|
||||
auto zeroConst =
|
||||
self.create<mlir::arith::ConstantIntOp>(loc, 0, elementType);
|
||||
self.create<mlir::arith::ConstantIntOp>(0, elementType);
|
||||
if (lhs.getType().isIntOrIndex()) {
|
||||
auto cmpValue = self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::ult, rhs, constValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShRUIOp>(loc, lhs, rhs);
|
||||
mlir::arith::CmpIPredicate::ult, rhs, constValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShRUIOp>(lhs, rhs);
|
||||
return mlir::Value(self.create<mlir::arith::SelectOp>(
|
||||
loc, cmpValue, shiftValue, zeroConst));
|
||||
cmpValue, shiftValue, zeroConst));
|
||||
} else {
|
||||
auto splatValue = self.create<mlir::tensor::SplatOp>(
|
||||
loc, lhs.getType(), constValue);
|
||||
lhs.getType(), constValue);
|
||||
auto zeroValue = self.create<mlir::tensor::SplatOp>(
|
||||
loc, lhs.getType(), zeroConst);
|
||||
lhs.getType(), zeroConst);
|
||||
auto cmpValue = self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::ult, rhs, splatValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShRUIOp>(loc, lhs, rhs);
|
||||
mlir::arith::CmpIPredicate::ult, rhs, splatValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShRUIOp>(lhs, rhs);
|
||||
return mlir::Value(self.create<mlir::arith::SelectOp>(
|
||||
loc, cmpValue, shiftValue, zeroValue));
|
||||
cmpValue, shiftValue, zeroValue));
|
||||
}
|
||||
#else
|
||||
return mlir::Value(
|
||||
self.create<mlir::arith::ShRUIOp>(loc, lhs, rhs));
|
||||
#endif
|
||||
=======
|
||||
return mlir::Value(self.create<mlir::arith::ShRUIOp>(lhs, rhs));
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
#endif
|
||||
})
|
||||
.def("create_ashr",
|
||||
[](TritonOpBuilder &self, mlir::Value &lhs,
|
||||
mlir::Value &rhs) -> mlir::Value {
|
||||
<<<<<<< HEAD
|
||||
auto loc = self.getUnknownLoc();
|
||||
#ifdef USE_ROCM
|
||||
mlir::Type elementType = getElementTypeOrSelf(lhs.getType());
|
||||
unsigned typeWidth = elementType.getIntOrFloatBitWidth();
|
||||
auto constValue = self.create<mlir::arith::ConstantIntOp>(
|
||||
loc, typeWidth, elementType);
|
||||
typeWidth, elementType);
|
||||
auto zeroConst =
|
||||
self.create<mlir::arith::ConstantIntOp>(loc, 0, elementType);
|
||||
self.create<mlir::arith::ConstantIntOp>(0, elementType);
|
||||
uint64_t ones_val = 0xFFFFFFFFFFFFFFFF;
|
||||
auto onesConst =
|
||||
self.create<mlir::arith::ConstantIntOp>(loc, ones_val, elementType);
|
||||
self.create<mlir::arith::ConstantIntOp>(ones_val, elementType);
|
||||
if (lhs.getType().isIntOrIndex()) {
|
||||
auto negativeCmpValue = self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::slt, lhs, zeroConst);
|
||||
mlir::arith::CmpIPredicate::slt, lhs, zeroConst);
|
||||
auto otherValue = mlir::Value(self.create<mlir::arith::SelectOp>(
|
||||
loc, negativeCmpValue, onesConst, zeroConst));
|
||||
negativeCmpValue, onesConst, zeroConst));
|
||||
auto cmpValue = self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::ult, rhs, constValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShRSIOp>(loc, lhs, rhs);
|
||||
mlir::arith::CmpIPredicate::ult, rhs, constValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShRSIOp>(lhs, rhs);
|
||||
return mlir::Value(self.create<mlir::arith::SelectOp>(
|
||||
loc, cmpValue, shiftValue, otherValue));
|
||||
cmpValue, shiftValue, otherValue));
|
||||
} else {
|
||||
auto splatValue = self.create<mlir::tensor::SplatOp>(
|
||||
loc, lhs.getType(), constValue);
|
||||
lhs.getType(), constValue);
|
||||
auto zeroValue = self.create<mlir::tensor::SplatOp>(
|
||||
loc, lhs.getType(), zeroConst);
|
||||
lhs.getType(), zeroConst);
|
||||
auto onesValue = self.create<mlir::tensor::SplatOp>(
|
||||
loc, lhs.getType(), onesConst);
|
||||
lhs.getType(), onesConst);
|
||||
auto negativeCmpValue = self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::slt, lhs, zeroValue);
|
||||
mlir::arith::CmpIPredicate::slt, lhs, zeroValue);
|
||||
auto otherValue = mlir::Value(self.create<mlir::arith::SelectOp>(
|
||||
loc, negativeCmpValue, onesValue, zeroValue));
|
||||
negativeCmpValue, onesValue, zeroValue));
|
||||
auto cmpValue = self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::ult, rhs, splatValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShRSIOp>(loc, lhs, rhs);
|
||||
mlir::arith::CmpIPredicate::ult, rhs, splatValue);
|
||||
auto shiftValue = self.create<mlir::arith::ShRSIOp>(lhs, rhs);
|
||||
return mlir::Value(self.create<mlir::arith::SelectOp>(
|
||||
loc, cmpValue, shiftValue, otherValue));
|
||||
cmpValue, shiftValue, otherValue));
|
||||
}
|
||||
#else
|
||||
return mlir::Value(
|
||||
self.create<mlir::arith::ShRSIOp>(loc, lhs, rhs));
|
||||
#endif
|
||||
=======
|
||||
return mlir::Value(self.create<mlir::arith::ShRSIOp>(lhs, rhs));
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
#endif
|
||||
})
|
||||
// AddPtr (similar to GEP)
|
||||
.def("create_addptr",
|
||||
|
||||
@@ -14,14 +14,9 @@ from typing import Any, Tuple
|
||||
from .._C.libtriton.triton import (add_external_libs, compile_ptx_to_cubin,
|
||||
get_shared_memory_size, ir,
|
||||
translate_llvmir_to_hsaco, translate_llvmir_to_ptx,
|
||||
<<<<<<< HEAD
|
||||
translate_triton_gpu_to_llvmir, get_arch_info,
|
||||
get_warp_size)
|
||||
from ..common.backend import get_backend
|
||||
=======
|
||||
translate_triton_gpu_to_llvmir)
|
||||
from ..common.backend import get_backend, path_to_ptxas
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
# from ..runtime import driver, jit, JITFunction
|
||||
# TODO: runtime.errors
|
||||
from ..runtime.autotuner import OutOfResources
|
||||
|
||||
@@ -294,15 +294,12 @@ class _attention(torch.autograd.Function):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 8:
|
||||
raise RuntimeError("Flash attention currently only supported for compute capability >= 80")
|
||||
<<<<<<< HEAD
|
||||
if torch.version.hip is not None:
|
||||
BLOCK = 64
|
||||
BLOCK_M = 64
|
||||
BLOCK_N = 64
|
||||
else:
|
||||
BLOCK = 128
|
||||
=======
|
||||
BLOCK_M = 128
|
||||
BLOCK_N = 64
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
BLOCK_M = 128
|
||||
BLOCK_N = 64
|
||||
# shape constraints
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk and Lk == Lv
|
||||
|
||||
@@ -25,11 +25,7 @@ class OutOfResources(Exception):
|
||||
|
||||
|
||||
class Autotuner(KernelInterface):
|
||||
<<<<<<< HEAD
|
||||
def __init__(self, fn, arg_names, configs, key, verbose, reset_to_zero, prune_configs_by: Dict = None):
|
||||
=======
|
||||
def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, warmup=25, rep=100):
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
def __init__(self, fn, arg_names, configs, key, verbose, reset_to_zero, prune_configs_by: Dict = None, warmup=25, rep=100):
|
||||
'''
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predicate running time with different configs, returns running time
|
||||
@@ -62,12 +58,9 @@ class Autotuner(KernelInterface):
|
||||
self.perf_model, self.configs_top_k = perf_model, top_k
|
||||
self.early_config_prune = early_config_prune
|
||||
self.fn = fn
|
||||
<<<<<<< HEAD
|
||||
self.verbose = verbose
|
||||
=======
|
||||
self.warmup = warmup
|
||||
self.rep = rep
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
self.verbose = verbose
|
||||
|
||||
def _bench(self, *args, config, **meta):
|
||||
# check for conflicts, i.e. meta-parameters both provided
|
||||
@@ -187,11 +180,7 @@ class Config:
|
||||
return ', '.join(res)
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=False):
|
||||
=======
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25, rep=100):
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=False, warmup=25, rep=100):
|
||||
"""
|
||||
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
||||
|
||||
@@ -222,21 +211,15 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25,
|
||||
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
|
||||
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
|
||||
:type reset_to_zero: list[str]
|
||||
<<<<<<< HEAD
|
||||
:param verbose: a boolean that controls whether the best_config for each key is printed
|
||||
:type verbose: bool
|
||||
"""
|
||||
def decorator(fn):
|
||||
return Autotuner(fn, fn.arg_names, configs, key, verbose, reset_to_zero, prune_configs_by)
|
||||
=======
|
||||
:param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25.
|
||||
:type warmup: int
|
||||
:param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
|
||||
:type rep: int
|
||||
:param verbose: a boolean that controls whether the best_config for each key is printed
|
||||
:type verbose: bool
|
||||
"""
|
||||
def decorator(fn):
|
||||
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, warmup, rep)
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
return Autotuner(fn, fn.arg_names, configs, key, verbose, reset_to_zero, prune_configs_by, warmup, rep)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@@ -325,12 +325,8 @@ class JITFunction(KernelInterface[T]):
|
||||
args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
|
||||
|
||||
src = f"""
|
||||
<<<<<<< HEAD
|
||||
|
||||
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
=======
|
||||
def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
from ..compiler import compile, CompiledKernel
|
||||
sig_key = {sig_keys},
|
||||
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
|
||||
|
||||
@@ -1,125 +0,0 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from .._C.libtriton.triton import ir
|
||||
# import triton.compiler.compiler as tc
|
||||
from ..compiler.compiler import (get_amdgpu_arch_fulldetails, llir_to_amdgcn_and_hsaco,
|
||||
llir_to_ptx, optimize_ttgir, optimize_ttir,
|
||||
ttgir_to_llir, ttir_to_ttgir, CUDA_DEFAULT_WARP_SIZE)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# valid source and target formats
|
||||
VALID_FORMATS = ['triton-ir', 'triton-gpu-ir', 'llvm-ir', 'ptx', 'amdgcn']
|
||||
|
||||
# set up the argument parser
|
||||
# TODO: conditional requirements
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('src', help="Source file to compile")
|
||||
parser.add_argument('--target', required=True,
|
||||
help="Target format, one of: " + ', '.join(VALID_FORMATS))
|
||||
parser.add_argument('--sm', type=int, help="Compute capability to compile for")
|
||||
parser.add_argument('--ptx-version', type=int, help="PTX version to compile for")
|
||||
parser.add_argument('--gfx', type=str, help="AMDGPU target to compile for")
|
||||
parser.add_argument('--triple', type=str, help="target triple, for example: amdgcn-amd-amdhsa")
|
||||
parser.add_argument('--features', type=str, help="target features, for example: +sramecc,-xnack")
|
||||
parser.add_argument('--num_warps', type=int, help="number of warps to compile ttgir for")
|
||||
|
||||
# parse the args
|
||||
args = parser.parse_args()
|
||||
|
||||
# TODO: clean-up and re-use triton.compiler primitive functions
|
||||
# check for validity of format arguments
|
||||
if args.target not in VALID_FORMATS:
|
||||
print("Invalid target format: " + args.target)
|
||||
sys.exit(0)
|
||||
|
||||
# parse source file to MLIR module
|
||||
context = ir.context()
|
||||
module = ir.parse_mlir_module(args.src, context)
|
||||
module.context = context
|
||||
|
||||
# optimizer triton-ir
|
||||
module = optimize_ttir(module, arch=args.sm)
|
||||
if args.target == 'triton-ir':
|
||||
print(module.str())
|
||||
sys.exit(0)
|
||||
|
||||
if not args.num_warps:
|
||||
args.num_warps = 4
|
||||
|
||||
# llvm-ir -> amdgcn
|
||||
if args.target == 'amdgcn':
|
||||
# auto detect available architecture and features
|
||||
# if nothing detected, set with default values
|
||||
arch_details = get_amdgpu_arch_fulldetails()
|
||||
if not arch_details:
|
||||
arch_name = ""
|
||||
arch_triple = "amdgcn-amd-amdhsa"
|
||||
arch_features = ""
|
||||
arch_warpsize = 64
|
||||
else:
|
||||
arch_triple, arch_name, arch_features, arch_warpsize = arch_details
|
||||
|
||||
# stop processing if architecture name is not automatically detected and is not set manually
|
||||
if not args.gfx and not arch_name:
|
||||
raise argparse.ArgumentError(None, "Must specify --gfx for AMDGCN compilation")
|
||||
|
||||
# rewrite default and automatically detected values with manually provided data
|
||||
if args.gfx:
|
||||
arch_name = args.gfx
|
||||
if args.triple:
|
||||
arch_triple = args.triple
|
||||
if args.features:
|
||||
arch_features = args.features
|
||||
|
||||
# triton-ir -> triton-gpu-ir
|
||||
# use compute_capability == 80
|
||||
module = ttir_to_ttgir(module, num_warps=args.num_warps, warpsize=arch_warpsize) # num_stages=3, compute_capability=80)
|
||||
module = optimize_ttgir(module, num_stages=3, arch=args.gfx)
|
||||
# triton-gpu-ir -> llvm-ir
|
||||
# use compute_capability == 80
|
||||
module = ttgir_to_llir(module, extern_libs=None, arch=args.gfx)
|
||||
# llvm-ir -> amdgcn asm, hsaco binary
|
||||
module, hsaco_path = llir_to_amdgcn_and_hsaco(module, arch_name, arch_triple, arch_features)
|
||||
|
||||
print(hsaco_path)
|
||||
print(module)
|
||||
sys.exit(0)
|
||||
|
||||
# set arch depending on platform
|
||||
if args.gfx:
|
||||
arch = args.gfx
|
||||
elif args.sm:
|
||||
arch = args.sm
|
||||
else:
|
||||
raise argparse.ArgumentError(None, "Must specify --sm or --gfx for ttgir compilation")
|
||||
|
||||
# triton-ir -> triton-gpu-ir
|
||||
module = ttir_to_ttgir(module, num_warps=args.num_warps, warpsize=CUDA_DEFAULT_WARP_SIZE)
|
||||
module = optimize_ttgir(module, num_stages=3, arch=arch)
|
||||
if args.target == 'triton-gpu-ir':
|
||||
print(module.str())
|
||||
sys.exit(0)
|
||||
|
||||
# triton-gpu-ir -> llvm-ir
|
||||
module = ttgir_to_llir(module, extern_libs=None, arch=arch)
|
||||
if args.target == 'llvm-ir':
|
||||
print(module)
|
||||
sys.exit(0)
|
||||
|
||||
# llvm-ir -> ptx
|
||||
if args.target == 'ptx':
|
||||
if not args.sm:
|
||||
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
|
||||
if not args.ptx_version:
|
||||
raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
|
||||
module = llir_to_ptx(module, arch=args.sm, ptx_version=args.ptx_version)
|
||||
|
||||
# llvm-ir -> amdgcn
|
||||
if args.target == 'amdgcn':
|
||||
if not args.gfx:
|
||||
raise argparse.ArgumentError(None, "Must specify --gfx for AMDGCN compilation")
|
||||
module, hsaco_path = llir_to_amdgcn_and_hsaco(module, args.gfx)
|
||||
|
||||
print(module)
|
||||
@@ -39,16 +39,10 @@ def _fwd_kernel(
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
<<<<<<< HEAD
|
||||
q_offset = off_hz * stride_qh
|
||||
kv_offset = off_hz * stride_kh
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q + q_offset,
|
||||
=======
|
||||
qvk_offset = off_hz * stride_qh
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q + qvk_offset,
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
@@ -56,26 +50,16 @@ def _fwd_kernel(
|
||||
order=(1, 0)
|
||||
)
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
<<<<<<< HEAD
|
||||
base=K + kv_offset,
|
||||
shape=(BLOCK_DMODEL, N_CTX + P_SEQ),
|
||||
=======
|
||||
base=K + qvk_offset,
|
||||
shape=(BLOCK_DMODEL, N_CTX),
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||
order=(0, 1)
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
<<<<<<< HEAD
|
||||
base=V + kv_offset,
|
||||
shape=(N_CTX + P_SEQ, BLOCK_DMODEL),
|
||||
=======
|
||||
base=V + qvk_offset,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
strides=(stride_vk, stride_vn),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
@@ -97,11 +81,7 @@ def _fwd_kernel(
|
||||
q = (q * qk_scale).to(tl.float16)
|
||||
# loop over k, v and update accumulator
|
||||
lo = 0
|
||||
<<<<<<< HEAD
|
||||
hi = P_SEQ + (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + P_SEQ
|
||||
=======
|
||||
hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
for start_n in range(lo, hi, BLOCK_N):
|
||||
# -- load k, v --
|
||||
k = tl.load(K_block_ptr)
|
||||
@@ -109,11 +89,7 @@ def _fwd_kernel(
|
||||
# -- compute qk ---
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
if IS_CAUSAL:
|
||||
<<<<<<< HEAD
|
||||
qk = tl.where(P_SEQ + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
||||
=======
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
qk += tl.dot(q, k)
|
||||
# -- compute scaling constant ---
|
||||
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
|
||||
@@ -135,11 +111,7 @@ def _fwd_kernel(
|
||||
tl.store(l_ptrs, m_i + tl.math.log2(l_i))
|
||||
# write back O
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
<<<<<<< HEAD
|
||||
base=Out + q_offset,
|
||||
=======
|
||||
base=Out + qvk_offset,
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
@@ -152,11 +124,7 @@ def _fwd_kernel(
|
||||
@triton.jit
|
||||
def _bwd_preprocess(
|
||||
Out, DO,
|
||||
<<<<<<< HEAD
|
||||
NewDO, Delta,
|
||||
=======
|
||||
Delta,
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
|
||||
):
|
||||
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
@@ -164,12 +132,10 @@ def _bwd_preprocess(
|
||||
# load
|
||||
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
|
||||
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
# compute
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
delta = tl.sum(o * do, axis=1)
|
||||
# write-back
|
||||
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
|
||||
tl.store(Delta + off_m, delta)
|
||||
|
||||
|
||||
@@ -233,22 +199,13 @@ def _bwd_kernel(
|
||||
# load q, k, v, do on-chip
|
||||
q = tl.load(q_ptrs)
|
||||
# recompute p = softmax(qk, dim=-1).T
|
||||
<<<<<<< HEAD
|
||||
# NOTE: `do` is pre-divided by `l`; no normalization here
|
||||
qk = tl.dot(q, tl.trans(k))
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
|
||||
if CAUSAL:
|
||||
qk = tl.dot(q, tl.trans(k), out_dtype=tl.float32)
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
|
||||
else:
|
||||
qk = tl.dot(q, tl.trans(k), out_dtype=tl.float32)
|
||||
l_i = tl.load(l_ptrs + offs_m_curr)
|
||||
p = tl.math.exp2(qk * qk_scale - l_i[:, None])
|
||||
=======
|
||||
if CAUSAL:
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
|
||||
else:
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, tl.trans(k))
|
||||
qk *= qk_scale
|
||||
l_i = tl.load(l_ptrs + offs_m_curr)
|
||||
p = tl.math.exp2(qk - l_i[:, None])
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
# compute dv
|
||||
do = tl.load(do_ptrs)
|
||||
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
|
||||
@@ -492,18 +449,13 @@ empty = torch.empty(128, device="cuda")
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
<<<<<<< HEAD
|
||||
def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
|
||||
=======
|
||||
def forward(ctx, q, k, v, causal, sm_scale):
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
# shape constraints
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk and Lk == Lv
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
o = torch.empty_like(q)
|
||||
BLOCK_M = 128
|
||||
<<<<<<< HEAD
|
||||
if torch.version.hip is None:
|
||||
BLOCK_N = 64 if Lk <= 64 else 32
|
||||
num_stages = 4 if Lk <= 64 else 3
|
||||
@@ -514,11 +466,6 @@ class _attention(torch.autograd.Function):
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2]
|
||||
=======
|
||||
BLOCK_N = 64
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
_fwd_kernel[grid](
|
||||
@@ -529,48 +476,36 @@ class _attention(torch.autograd.Function):
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
<<<<<<< HEAD
|
||||
q.shape[0], q.shape[1], q.shape[2], P_SEQ,
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
|
||||
IS_CAUSAL=causal,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages)
|
||||
=======
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
|
||||
IS_CAUSAL=causal,
|
||||
num_warps=num_warps,
|
||||
num_stages=4)
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
|
||||
ctx.save_for_backward(q, k, v, o, L)
|
||||
ctx.grid = grid
|
||||
ctx.sm_scale = sm_scale
|
||||
ctx.BLOCK_DMODEL = Lk
|
||||
ctx.causal = causal
|
||||
<<<<<<< HEAD
|
||||
ctx.split_kernel = split_kernel
|
||||
ctx.P_SEQ = P_SEQ
|
||||
=======
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
return o
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, do):
|
||||
<<<<<<< HEAD
|
||||
BLOCK = 64
|
||||
q, k, v, o, l = ctx.saved_tensors
|
||||
=======
|
||||
BLOCK = 128
|
||||
# configuration is not supported
|
||||
assert(not (ctx.split_kernel and not ctx.causal))
|
||||
if torch.version.hip is not None:
|
||||
BLOCK = 64
|
||||
else:
|
||||
BLOCK = 128
|
||||
q, k, v, o, L = ctx.saved_tensors
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
do = do.contiguous()
|
||||
dq = torch.zeros_like(q)
|
||||
dq = torch.zeros_like(q, dtype=torch.float32)
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
<<<<<<< HEAD
|
||||
delta = torch.empty_like(L)
|
||||
do_scaled = torch.empty_like(do)
|
||||
delta = torch.empty_like(l)
|
||||
# Figure out what BLOCK size fwd used and adjust num_blocks accordingly.
|
||||
# If the two are the same, we don't need this but the bwd pass block size
|
||||
# is smaller than the fwd so we need this scaling to ensure we loop over all
|
||||
@@ -588,8 +523,7 @@ class _attention(torch.autograd.Function):
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do_scaled,
|
||||
dq, dk, dv,
|
||||
l,
|
||||
delta,
|
||||
L, delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
@@ -597,15 +531,16 @@ class _attention(torch.autograd.Function):
|
||||
block_scale * ctx.grid[0],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
|
||||
CAUSAL=ctx.causal,
|
||||
num_stages=1,
|
||||
)
|
||||
else :
|
||||
dq = torch.zeros_like(q)
|
||||
_bwd_kernel_dk_dv[(block_scale * ctx.grid[0], ctx.grid[1])](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do_scaled,
|
||||
dk, dv,
|
||||
l,
|
||||
delta,
|
||||
L, delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
@@ -618,8 +553,7 @@ class _attention(torch.autograd.Function):
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do_scaled,
|
||||
dq,
|
||||
l,
|
||||
delta,
|
||||
L, delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
@@ -630,36 +564,10 @@ class _attention(torch.autograd.Function):
|
||||
)
|
||||
# print(h.asm["ttgir"])
|
||||
return dq, dk, dv, None, None, None
|
||||
=======
|
||||
delta = torch.empty_like(L)
|
||||
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
|
||||
o, do,
|
||||
delta,
|
||||
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
_bwd_kernel[(ctx.grid[1],)](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do,
|
||||
dq, dk, dv,
|
||||
L, delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
ctx.grid[0],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
|
||||
CAUSAL=ctx.causal,
|
||||
num_stages=1,
|
||||
)
|
||||
return dq, dk, dv, None, None
|
||||
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
|
||||
attention = _attention.apply
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, P_SEQ',
|
||||
[(4, 48, 1024, 64, 128),
|
||||
(4, 48, 2048, 64, 128),
|
||||
@@ -702,16 +610,6 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, P_SEQ, dtype=torch.float16):
|
||||
v = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
sm_scale = q.shape[-1] ** (-0.5)
|
||||
split_kernel = True
|
||||
=======
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 9, 1024, 64)])
|
||||
@pytest.mark.parametrize('causal', [False, True])
|
||||
def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
|
||||
torch.manual_seed(20)
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
sm_scale = 0.5
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
dout = torch.randn_like(q)
|
||||
# reference implementation
|
||||
M = torch.tril(torch.ones((N_CTX, N_CTX + P_SEQ), device="cuda"), diagonal=P_SEQ)
|
||||
@@ -724,13 +622,8 @@ def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
|
||||
ref_dv, v.grad = v.grad.clone(), None
|
||||
ref_dk, k.grad = k.grad.clone(), None
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
<<<<<<< HEAD
|
||||
# # triton implementation
|
||||
tri_out = attention(q, k, v, causal, sm_scale, split_kernel)
|
||||
=======
|
||||
# triton implementation
|
||||
tri_out = attention(q, k, v, causal, sm_scale).half()
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
tri_out.backward(dout)
|
||||
tri_dv, v.grad = v.grad.clone(), None
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
@@ -771,11 +664,7 @@ configs = [triton.testing.Benchmark(
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
||||
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode, 'causal': causal}
|
||||
<<<<<<< HEAD
|
||||
) for mode in ['fwd', 'bwd'] for causal in [True, False]]
|
||||
=======
|
||||
) for mode in ['fwd', 'bwd'] for causal in [False, True]]
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
@@ -793,11 +682,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
|
||||
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
sm_scale = 1.3
|
||||
<<<<<<< HEAD
|
||||
fn = lambda: attention(q, k, v, causal, sm_scale, split_kernel)
|
||||
=======
|
||||
fn = lambda: attention(q, k, v, causal, sm_scale)
|
||||
>>>>>>> 5df904233c11a65bd131ead7268f84cca7804275
|
||||
if mode == 'bwd':
|
||||
o = fn()
|
||||
do = torch.randn_like(o)
|
||||
|
||||
Reference in New Issue
Block a user