Merge commit 'cb3d79a185e40c9d8a579bea07747a8a8d157d52' into ifu-231117

Conflicts:
	lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
	lib/Dialect/TritonGPU/IR/Dialect.cpp
	python/setup.py
	python/test/unit/language/assert_helper.py
	python/test/unit/operators/test_flash_attention.py
	python/test/unit/runtime/test_subproc.py
	python/triton/compiler/compiler.py
	python/triton/language/semantic.py
	python/triton/runtime/autotuner.py
	python/triton/runtime/jit.py
	python/tutorials/03-matrix-multiplication.py
	python/tutorials/05-layer-norm.py
	python/tutorials/06-fused-attention.py
	python/tutorials/11-grouped-gemm.py
	test/Conversion/tritongpu_to_llvm.mlir
Jason Furmanek
2023-11-17 20:42:12 +00:00
179 changed files with 10116 additions and 6835 deletions

View File

@@ -4,8 +4,8 @@ import triton.language as tl
# triton kernel
@triton.jit
def kernel(X, stride_xm,
Z, stride_zn,
def kernel(X, stride_xm, #
Z, stride_zn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)

View File

@@ -10,4 +10,4 @@ def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
X = torch.randn(1, device="cuda")
pgm = kernel[(1,)](X, 1, 1, BLOCK=1024)
pgm = kernel[(1, )](X, 1, 1, BLOCK=1024)

View File

@@ -2,7 +2,14 @@
[build-system]
requires = ["setuptools>=40.8.0", "wheel", "cmake>=3.18", "ninja>=1.11.1"]
# We're incrementally switching from autopep8 to ruff.
[tool.autopep8]
aggressive = 1
ignore = "E501,E701,E731,W690"
ignore = "E501,E701,E731,W690,W503"
max_line_length = 88
[tool.ruff]
line-length = 120
[tool.ruff.lint]
ignore = ["E501", "E701", "E731", "E741"]

View File

@@ -55,6 +55,7 @@ class Package(NamedTuple):
lib_flag: str
syspath_var_name: str
# pybind11
@@ -63,6 +64,7 @@ def get_pybind11_package_info():
url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.11.1.tar.gz"
return Package("pybind11", name, url, "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH")
# llvm
@@ -74,6 +76,8 @@ def get_llvm_package_info():
arch = 'arm64'
if system == "Darwin":
arch = platform.machine()
if arch == "x86_64":
arch = "x64"
system_suffix = f"macos-{arch}"
elif system == "Linux":
# TODO: arm64
@@ -84,7 +88,7 @@ def get_llvm_package_info():
return Package("llvm", "LLVM-C.lib", "", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
# use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
# release_suffix = "assert" if use_assert_enabled_llvm else "release"
rev = "b1115f8c"
rev = "49af6502"
name = f"llvm-{rev}-{system_suffix}"
url = f"https://tritonlang.blob.core.windows.net/llvm-builds/{name}.tar.gz"
return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
@@ -119,10 +123,13 @@ def get_thirdparty_packages(triton_cache_path):
thirdparty_cmake_args.append(f"-D{p.lib_flag}={package_dir}/lib")
return thirdparty_cmake_args
# ---- package data ---
def download_and_copy(src_path, version, url_func):
def download_and_copy(src_path, variable, version, url_func):
if variable in os.environ:
return
base_dir = os.path.dirname(__file__)
arch = platform.machine()
if arch == "x86_64":
@@ -148,7 +155,7 @@ def download_and_copy(src_path, version, url_func):
src_path = os.path.join(temp_dir, src_path)
os.makedirs(os.path.split(dst_path)[0], exist_ok=True)
shutil.copy(src_path, dst_path)
return dst_suffix
# ---- cmake extension ----
@@ -167,18 +174,21 @@ def get_cmake_dir():
class CMakeClean(clean):
def initialize_options(self):
clean.initialize_options(self)
self.build_temp = get_cmake_dir()
class CMakeBuildPy(build_py):
def run(self) -> None:
self.run_command('build_ext')
return super().run()
class CMakeExtension(Extension):
def __init__(self, name, path, sourcedir=""):
Extension.__init__(self, name, sources=[])
self.sourcedir = os.path.abspath(sourcedir)
@@ -201,7 +211,8 @@ class CMakeBuild(build_ext):
try:
out = subprocess.check_output(["cmake", "--version"])
except OSError:
raise RuntimeError("CMake must be installed to build the following extensions: " + ", ".join(e.name for e in self.extensions))
raise RuntimeError("CMake must be installed to build the following extensions: " +
", ".join(e.name for e in self.extensions))
match = re.search(r"version\s*(?P<major>\d+)\.(?P<minor>\d+)([\d.]+)?", out.decode())
cmake_major, cmake_minor = int(match.group("major")), int(match.group("minor"))
@@ -228,8 +239,10 @@ class CMakeBuild(build_ext):
# python directories
python_include_dir = sysconfig.get_path("platinclude")
cmake_args = [
"-G", "Ninja", # Ninja is much faster than make
"-DCMAKE_MAKE_PROGRAM=" + ninja_dir, # Pass explicit path to ninja otherwise cmake may cache a temporary path
"-G",
"Ninja", # Ninja is much faster than make
"-DCMAKE_MAKE_PROGRAM=" +
ninja_dir, # Pass explicit path to ninja otherwise cmake may cache a temporary path
"-DCMAKE_EXPORT_COMPILE_COMMANDS=ON",
"-DLLVM_ENABLE_WERROR=ON",
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
@@ -263,12 +276,28 @@ class CMakeBuild(build_ext):
build_args += ['-j' + max_jobs]
if check_env_flag("TRITON_BUILD_WITH_CLANG_LLD"):
cmake_args += ["-DCMAKE_C_COMPILER=clang",
"-DCMAKE_CXX_COMPILER=clang++",
"-DCMAKE_LINKER=lld",
"-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld",
"-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld",
"-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld"]
cmake_args += [
"-DCMAKE_C_COMPILER=clang",
"-DCMAKE_CXX_COMPILER=clang++",
"-DCMAKE_LINKER=lld",
"-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld",
"-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld",
"-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld",
]
# Note that asan doesn't work with binaries that use the GPU, so this is
# only useful for tools like triton-opt that don't run code on the GPU.
#
# I tried and gave up getting msan to work. It seems that libstdc++'s
# std::string does not play nicely with clang's msan (I didn't try
# gcc's). I was unable to configure clang to ignore the error, and I
# also wasn't able to get libc++ to work, but that doesn't mean it's
# impossible. :)
if check_env_flag("TRITON_BUILD_WITH_ASAN"):
cmake_args += [
"-DCMAKE_C_FLAGS=-fsanitize=address",
"-DCMAKE_CXX_FLAGS=-fsanitize=address",
]
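# Editor's sketch (not part of this diff): one way the TRITON_BUILD_WITH_ASAN
# flag described above might be exercised from a source checkout. The editable
# install command and triton-opt being on PATH afterwards are assumptions.
import os
import subprocess
asan_env = dict(os.environ, TRITON_BUILD_WITH_ASAN="1")
# Rebuild the C++ extension with -fsanitize=address; per the note above, only
# host-side tools benefit from this.
subprocess.check_call(["pip", "install", "-e", "python"], env=asan_env)
# triton-opt launches no GPU code, so it can run under ASan.
subprocess.check_call(["triton-opt", "--help"], env=asan_env)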
if check_env_flag("TRITON_BUILD_WITH_CCACHE"):
cmake_args += [
@@ -282,9 +311,27 @@ class CMakeBuild(build_ext):
subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir)
download_and_copy(src_path='bin/ptxas', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2")
download_and_copy(src_path='bin/cuobjdump', version='12.1.111', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2")
download_and_copy(src_path='bin/nvdisasm', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2")
download_and_copy(
src_path="bin/ptxas",
variable="TRITON_PTXAS_PATH",
version="12.1.105",
url_func=lambda arch, version:
f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2",
)
download_and_copy(
src_path="bin/cuobjdump",
variable="TRITON_CUOBJDUMP_PATH",
version="12.1.111",
url_func=lambda arch, version:
f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2",
)
download_and_copy(
src_path="bin/nvdisasm",
variable="TRITON_NVDISASM_PATH",
version="12.1.105",
url_func=lambda arch, version:
f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2",
)
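# Editor's sketch (not part of this diff): the `variable` guard added above lets a
# pre-set environment variable skip the conda download entirely. Hypothetical use,
# assuming a locally installed CUDA toolkit already provides ptxas:
import os
os.environ.setdefault("TRITON_PTXAS_PATH", "/usr/local/cuda/bin/ptxas")
# With TRITON_PTXAS_PATH present, the corresponding download_and_copy call above
# returns early and the bundled ptxas is not fetched or copied.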
setup(
name="triton",
@@ -307,10 +354,14 @@ setup(
"triton/third_party",
"triton/tools",
],
long_description_content_type="text/markdown",
install_requires=["filelock"],
include_package_data=True,
ext_modules=[CMakeExtension("triton", "triton/_C/")],
cmdclass={"build_ext": CMakeBuild, "build_py": CMakeBuildPy, "clean": CMakeClean},

View File

@@ -229,6 +229,12 @@ void init_triton_ir(py::module &&m) {
.value("RELAXED", mlir::triton::MemSemantic::RELAXED)
.export_values();
py::enum_<mlir::triton::MemSyncScope>(m, "MEM_SYNC_SCOPE", py::module_local())
.value("GPU", mlir::triton::MemSyncScope::GPU)
.value("CTA", mlir::triton::MemSyncScope::CTA)
.value("SYSTEM", mlir::triton::MemSyncScope::SYSTEM)
.export_values();
py::enum_<mlir::triton::EvictionPolicy>(m, "EVICTION_POLICY",
py::module_local())
.value("NORMAL", mlir::triton::EvictionPolicy::NORMAL)
@@ -1527,7 +1533,8 @@ void init_triton_ir(py::module &&m) {
// // atomic
.def("create_atomic_cas",
[](TritonOpBuilder &self, mlir::Value &ptr, mlir::Value &cmp,
mlir::Value &val, mlir::triton::MemSemantic sem) -> mlir::Value {
mlir::Value &val, mlir::triton::MemSemantic sem,
mlir::triton::MemSyncScope scope) -> mlir::Value {
mlir::Type dstType;
if (auto srcTensorType =
ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
@@ -1542,12 +1549,13 @@ void init_triton_ir(py::module &&m) {
dstType = ptrType.getPointeeType();
}
return self.create<mlir::triton::AtomicCASOp>(dstType, ptr, cmp,
val, sem);
val, sem, scope);
})
.def("create_atomic_rmw",
[](TritonOpBuilder &self, mlir::triton::RMWOp rmwOp,
mlir::Value &ptr, mlir::Value &val, mlir::Value &mask,
mlir::triton::MemSemantic sem) -> mlir::Value {
mlir::triton::MemSemantic sem,
mlir::triton::MemSyncScope scope) -> mlir::Value {
mlir::Type dstType;
if (auto srcTensorType =
ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
@@ -1561,8 +1569,8 @@ void init_triton_ir(py::module &&m) {
.cast<mlir::triton::PointerType>();
dstType = ptrType.getPointeeType();
}
return self.create<mlir::triton::AtomicRMWOp>(dstType, rmwOp, ptr,
val, mask, sem);
return self.create<mlir::triton::AtomicRMWOp>(
dstType, rmwOp, ptr, val, mask, sem, scope);
})
// External
.def("create_extern_elementwise",
@@ -1764,6 +1772,10 @@ void init_triton_ir(py::module &&m) {
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUCoalescePass());
})
.def("add_tritongpu_optimize_thread_locality_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUOptimizeThreadLocalityPass());
})
.def("add_symbol_dce_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createSymbolDCEPass());

View File

@@ -13,12 +13,11 @@ import torch
import triton
import triton.language as tl
from triton.common.backend import BaseBackend, register_backend
from triton.common.backend import (BaseBackend, compute_core_version_key, register_backend)
from triton.common.build import quiet
from triton.compiler.make_launcher import make_so_cache_key
from triton.runtime.cache import get_cache_manager
from triton.runtime.driver import DriverBase
from triton.runtime.jit import version_key
def build_for_backend(name, src, srcdir):
@@ -81,6 +80,7 @@ def build_for_backend(name, src, srcdir):
class ExtensionUtils:
def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(ExtensionUtils, cls).__new__(cls)
@@ -110,6 +110,7 @@ class ExtensionUtils:
class ExtensionDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(ExtensionDriver, cls).__new__(cls)
@@ -125,6 +126,7 @@ class ExtensionBackend(BaseBackend):
def __init__(self, device_type: str) -> None:
super(ExtensionBackend, self).__init__(device_type)
self.driver = ExtensionDriver()
self.version_key = None
def add_stages(self, arch, extern_libs, stages):
filter_in_stages = ["ast", "ttir", "ttgir"]
@@ -163,9 +165,14 @@ class ExtensionBackend(BaseBackend):
def get_architecture_descriptor(self, **kwargs):
return ""
def get_version_key(self):
if self.version_key is None:
self.version_key = compute_core_version_key()
return self.version_key
def make_launcher_stub(self, name, signature, constants):
# name of files that are cached
so_cache_key = make_so_cache_key(version_key(), signature, constants)
so_cache_key = make_so_cache_key(self.get_version_key(), signature, constants)
so_cache_manager = get_cache_manager(so_cache_key)
so_name = f"{name}.so"
# retrieve stub from cache if it exists
@@ -250,13 +257,13 @@ def test_dummy_backend():
inp = torch.randn(10)
out = torch.randn(10)
kernel[(10,)](inp, out, 10, XBLOCK=16)
kernel[(10, )](inp, out, 10, XBLOCK=16)
spec = importlib.util.spec_from_file_location("__triton_launcher", ExtensionBackend.stub_so_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
launch_counter = getattr(mod, "launch_counter")
for _ in range(100):
kernel[(10,)](inp, out, 10, XBLOCK=16)
kernel[(10, )](inp, out, 10, XBLOCK=16)
assert launch_counter() > 0

View File

@@ -4,9 +4,7 @@ import pytest
def pytest_addoption(parser):
parser.addoption(
"--backend", action="store", default="", help="Codegen backend"
)
parser.addoption("--backend", action="store", default="", help="Codegen backend")
@pytest.fixture

View File

@@ -24,10 +24,10 @@ def test_xpu_backend(cmdopt):
if has_ipex:
for _ in range(1000):
x = torch.randn((65536,), device="xpu", dtype=torch.float32)
y = torch.randn((65536,), device="xpu", dtype=torch.float32)
z = torch.zeros((65536,), device="xpu", dtype=torch.float32)
kernel[(65536,)](x, y, z, num_warps=32)
x = torch.randn((65536, ), device="xpu", dtype=torch.float32)
y = torch.randn((65536, ), device="xpu", dtype=torch.float32)
z = torch.zeros((65536, ), device="xpu", dtype=torch.float32)
kernel[(65536, )](x, y, z, num_warps=32)
assert torch.all(x + y == z)
else:
return

View File

@@ -0,0 +1,100 @@
"""
issue: https://github.com/openai/triton/issues/2523
Fused type conversion and matmul, based on the Triton matmul. Differences from the standard matmul:
1. force C's dtype (dot_out_dtype) to one of ["float16", "float32"]
2. accept A and B with dtype in ["float32", "float64"]
"""
import pytest
import torch
import triton.language as tl
from triton import cdiv, jit
input_dtypes = ["float32", "float64"]
out_dtypes = ["float16", "float32"]
@pytest.mark.parametrize("M, K, N, w_dtype, x_dtype, out_dtype",
[(M, K, N, w, x, o) #
for (M, K, N) in [(128, 128, 128), (1280, 768, 1024)] #
for w in input_dtypes
for x in input_dtypes #
for o in out_dtypes])
def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype):
if x_dtype == w_dtype:
pytest.skip("skip same dtype")
device = torch.cuda.current_device()
x_dtype = getattr(torch, x_dtype)
w_dtype = getattr(torch, w_dtype)
a = torch.randn((M, K), device=device, dtype=x_dtype)
b = torch.randn((K, N), device=device, dtype=w_dtype)
torch_dtype = getattr(torch, out_dtype)
triton_dtype = getattr(tl, out_dtype) # <- here force dot_out_dtype
out_torch = torch.matmul(a.to(torch_dtype), b.to(torch_dtype))
out_triton = torch.empty((M, N), device=device, dtype=torch_dtype)
allow_tf32 = True
# launch kernel
BLOCK_M, BLOCK_N, BLOCK_K = 16, 16, 32
grid = ((cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)), 1)
@jit
def matmul_kernel(A, B, C, M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
dot_out_dtype: tl.constexpr, #
allow_tf32: tl.constexpr, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, #
BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr):
# matrix multiplication
pid = tl.program_id(0)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = tl.arange(0, BLOCK_K)
# pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
for k in range(0, tl.cdiv(K, BLOCK_K)):
k_remaining = K - k * BLOCK_K
_0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
a = a.to(C.dtype.element_ty)
b = b.to(C.dtype.element_ty)
acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
acc = acc.to(C.dtype.element_ty)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
tl.store(C, acc, mask=mask)
matmul_kernel[grid](
a, b, out_triton, M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
out_triton.stride(0), out_triton.stride(1), dot_out_dtype=triton_dtype, #
allow_tf32=allow_tf32, #
GROUP_M=8, #
BLOCK_M=BLOCK_M, #
BLOCK_N=BLOCK_N, #
BLOCK_K=BLOCK_K)
torch.testing.assert_close(out_torch, out_triton, atol=0.3, rtol=0.01)
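# Editor's sketch (not part of this diff): the "re-order program ID for better L2
# performance" arithmetic from the kernel above, evaluated on the host for a small
# hypothetical 4x4 grid so the grouping is visible.
def swizzle(pid, grid_m, grid_n, GROUP_M):
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size
    return pid_m, pid_n
# Consecutive pids sweep a GROUP_M-row band of C column block by column block, so
# the A-row tiles they share (and each B-column tile) stay hot in L2.
for pid in range(16):
    print(pid, swizzle(pid, grid_m=4, grid_n=4, GROUP_M=2))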

View File

@@ -14,18 +14,14 @@ def test_chained_matmul():
return torch.einsum('MN,NK->MK', intermediate, c)
@triton.jit
def chained_matmul_kernel(
A, # shape: (m, k)
B, # shape: (n, k)
C, # shape: (n, k)
out, # shape: (m, k)
m, n, k: tl.constexpr,
block_m: tl.constexpr,
block_n: tl.constexpr,
block_k: tl.constexpr):
def chained_matmul_kernel(A, # shape: (m, k)
B, # shape: (n, k)
C, # shape: (n, k)
out, # shape: (m, k)
m, n, k: tl.constexpr, #
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr):
tl.static_assert(block_k == k,
f"expected block_k == k but got {block_k} != {k}")
tl.static_assert(block_k == k, f"expected block_k == k but got {block_k} != {k}")
block_ix = tl.program_id(0)
a_tile = (block_ix * block_m + tl.arange(0, block_m))[:, None] * block_k \
@@ -55,35 +51,33 @@ def test_chained_matmul():
m, n, k = 32, 64, 128
block_m, block_n, block_k = 16, 32, k
grid = (triton.cdiv(m, block_m),)
a = torch.randint(low=0, high=2, size=(m, k), dtype=torch.float16,
device='cuda')
b = torch.randint(low=0, high=2, size=(n, k), dtype=torch.float16,
device='cuda')
grid = (triton.cdiv(m, block_m), )
a = torch.randint(low=0, high=2, size=(m, k), dtype=torch.float16, device='cuda')
b = torch.randint(low=0, high=2, size=(n, k), dtype=torch.float16, device='cuda')
c = torch.randint_like(b, low=0, high=2)
triton_result = torch.zeros_like(a)
torch_result = chained_matmul_reference(a, b, c)
chained_matmul_kernel[grid](a, b, c, triton_result, m, n, k,
block_m=block_m, block_n=block_n,
block_k=block_k)
chained_matmul_kernel[grid](
a, b, c, triton_result, m, n, k, #
block_m=block_m, block_n=block_n, block_k=block_k)
assert (torch_result == triton_result).all()
def test_vecmat():
@triton.jit
def batched_vecmat(
# inputs
A, # shape: [dim_m, dim_k]
B, # shape: [dim_m, dim_n, dim_k]
# dimensions
# inputs
A, # shape: [dim_m, dim_k]
B, # shape: [dim_m, dim_n, dim_k]
# dimensions
dim_m, dim_n, dim_k,
# outputs
output,
# block information
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr
):
# outputs
output,
# block information
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr):
m_index = tl.program_id(0)
n_index = tl.program_id(1)
# Output tile
@@ -125,9 +119,10 @@ def test_vecmat():
grid = (M // block_m, N // block_n)
batched_vecmat[grid](A_tri, B_tri, M, N, K, C_tri,
block_m=block_m, block_n=block_n, block_k=block_k,
num_warps=4, num_stages=1)
batched_vecmat[grid](
A_tri, B_tri, M, N, K, C_tri, #
block_m=block_m, block_n=block_n, block_k=block_k, #
num_warps=4, num_stages=1)
A_expanded = A[:, np.newaxis, :]
A_broadcasted = np.broadcast_to(A_expanded, (M, N, K))
@@ -137,18 +132,18 @@ def test_vecmat():
np.testing.assert_allclose(C_ref, C_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
@pytest.mark.parametrize("type", ["pre_load", "post_load", "post_pre_mixed", "post_load_two_iters", "post_load_three_iters"])
@pytest.mark.parametrize("type",
["pre_load", "post_load", "post_pre_mixed", "post_load_two_iters", "post_load_three_iters"])
def test_iv_dependent_matmul(type):
@triton.jit
def kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
type: tl.constexpr
):
def kernel(a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
type: tl.constexpr):
pid = tl.program_id(axis=0)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
pid_m = pid // num_pid_n
@@ -216,15 +211,16 @@ def test_iv_dependent_matmul(type):
b = torch.rand((K, N), device='cuda')
torch_output = torch.mm(a, b)
triton_output = torch.empty_like(
torch_output, device=torch_output.device)
triton_output = torch.empty_like(torch_output, device=torch_output.device)
def grid(META):
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
num_stages = 4 if type == "post_load_three_iters" else 3
kernel[grid](a, b, triton_output, M, N, K, a.stride(0), a.stride(1),
b.stride(0), b.stride(1), triton_output.stride(0), triton_output.stride(1),
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
type=type, num_stages=num_stages)
kernel[grid](
a, b, triton_output, M, N, K, #
a.stride(0), a.stride(1), b.stride(0), b.stride(1), #
triton_output.stride(0), triton_output.stride(1), #
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, type=type, #
num_stages=num_stages)
torch.testing.assert_close(torch_output, triton_output, rtol=1e-2, atol=1e-2)

View File

@@ -26,7 +26,6 @@ sm_clocks = {'v100': 1350, 'a100': 1350}
mem_clocks = {'v100': 877, 'a100': 1215}
matmul_data = {
# NOTE:
'a100': {
# square
(512, 512, 512): {'float16': 0.108, 'float32': 0.097, 'int8': 0.05},
@@ -49,10 +48,9 @@ matmul_data = {
}
@pytest.mark.parametrize('M, N, K, dtype_str',
[(M, N, K, dtype_str)
for M, N, K in matmul_data[DEVICE_NAME].keys()
for dtype_str in ['float16']])
@pytest.mark.parametrize('M, N, K, dtype_str', [(M, N, K, dtype_str)
for M, N, K in matmul_data[DEVICE_NAME].keys()
for dtype_str in ['float16']])
def test_matmul(M, N, K, dtype_str):
stream = torch.cuda.Stream()
torch.cuda.set_stream(stream)
@@ -86,8 +84,7 @@ def test_matmul(M, N, K, dtype_str):
@triton.jit
def _add(x_ptr, y_ptr, output_ptr, n_elements,
BLOCK_SIZE: tl.constexpr):
def _add(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
@@ -136,36 +133,36 @@ def test_elementwise(N, dtype_str):
print_perf(ms, cur_gpu_util, ref_gpu_util)
triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01)
#######################
# Flash-Attention
#######################
flash_attention_data = {
"a100": {
(4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.542,
(4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.471,
(4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.155,
(4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.203,
(4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.202,
(4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.108,
(4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.232,
(4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.231,
(4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.138,
(4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.306,
(4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.266,
(4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.098,
(4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.134,
(4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.135,
(4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.066,
(4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.092,
(4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.541,
(4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.471,
(4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.150,
(4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.263,
(4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.291,
(4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.255,
(4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.144,
(4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.306,
(4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.266,
(4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.098,
(4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.159,
(4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.136,
(4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.159,
(4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.088,
}
}
@@ -221,8 +218,7 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str):
@triton.jit
def _sum(x_ptr, y_ptr, output_ptr, n_elements,
BLOCK_SIZE: tl.constexpr):
def _sum(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
@@ -260,8 +256,8 @@ def test_reductions(N, dtype_str):
y = torch.randn_like(z)
else:
info = torch.iinfo(dtype)
x = torch.randint(info.min, info.max, (N,), dtype=dtype, device='cuda')
y = torch.randint(info.min, info.max, (N,), dtype=dtype, device='cuda')
x = torch.randint(info.min, info.max, (N, ), dtype=dtype, device='cuda')
y = torch.randint(info.min, info.max, (N, ), dtype=dtype, device='cuda')
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
fn = lambda: _sum[grid](x, y, z, N, BLOCK_SIZE=1024)
ms = triton.testing.do_bench_cudagraph(fn)

View File

@@ -9,6 +9,7 @@ import yaml
class ComparisonResult:
def __init__(self, name: str, numComparisons: int, diffs: List[str] = None, errors: List[str] = None):
self.name = name
self.numComparisons = numComparisons
@@ -142,7 +143,8 @@ def doFilesMatch(path1: str, path2: str) -> bool:
return True
def compareMatchingFiles(name: str, nameToHashes1: Dict[str, List[str]], nameToHashes2: Dict[str, List[str]], args) -> ComparisonResult:
def compareMatchingFiles(name: str, nameToHashes1: Dict[str, List[str]], nameToHashes2: Dict[str, List[str]],
args) -> ComparisonResult:
"""
Compare files with the given name in all hashes in both paths
Return the first mismatching files as a tuple (file1, file2), otherwise, return an empty tuple

View File

@@ -18,7 +18,6 @@
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
Fused Attention
===============
@@ -35,18 +34,15 @@ import triton.language as tl
@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale,
L, M,
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX, D0,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
def _fwd_kernel(Q, K, V, sm_scale, #
L, M, #
Out, #
stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vk, stride_vn, #
stride_oz, stride_oh, stride_om, stride_on, #
Z, H, N_CTX, D0, #
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
@@ -61,31 +57,38 @@ def _fwd_kernel(
stride_qh_2d = stride_qh // stride_qm // stride_qk
q_tile_ptr = tl.make_block_ptr(base=Q,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(
off_hz * stride_qh_2d + start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
k_tile_ptr = tl.make_block_ptr(base=K,
shape=(D0, BLOCK_DMODEL),
strides=(stride_kn, stride_kk),
offsets=(off_hz * stride_qh_2d, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0))
v_tile_ptr = tl.make_block_ptr(base=V,
shape=(D0, BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(off_hz * stride_qh_2d, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0))
out_tile_ptr = tl.make_block_ptr(base=Out,
shape=(D0, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
q_tile_ptr = tl.make_block_ptr(
base=Q,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
k_tile_ptr = tl.make_block_ptr(
base=K,
shape=(D0, BLOCK_DMODEL),
strides=(stride_kn, stride_kk),
offsets=(off_hz * stride_qh_2d, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0),
)
v_tile_ptr = tl.make_block_ptr(
base=V,
shape=(D0, BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(off_hz * stride_qh_2d, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0),
)
out_tile_ptr = tl.make_block_ptr(
base=Out,
shape=(D0, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
# load q: it will stay in SRAM throughout
q = tl.load(q_tile_ptr)
@@ -96,8 +99,7 @@ def _fwd_kernel(
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (
start_n + offs_n[None, :]), qk, float("-inf"))
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
# compute new m
m_curr = tl.maximum(tl.max(qk, 1), m_prev)
# correct old l
@@ -133,11 +135,9 @@ def _fwd_kernel(
@triton.jit
def _bwd_preprocess(
Out, DO, L,
NewDO, Delta,
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
def _bwd_preprocess(Out, DO, L, #
NewDO, Delta, #
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
# load
@@ -153,19 +153,14 @@ def _bwd_preprocess(
@triton.jit
def _bwd_kernel(
Q, K, V, sm_scale, Out, DO,
DQ, DK, DV,
L, M,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX, D0,
num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
def _bwd_kernel(Q, K, V, sm_scale, Out, DO, #
DQ, DK, DV, #
L, M, #
D, stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vk, stride_vn, #
Z, H, N_CTX, D0, #
num_block, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr):
off_hz = tl.program_id(0)
off_z = off_hz // H
off_h = off_hz % H
@@ -173,55 +168,62 @@ def _bwd_kernel(
stride_qz_2d = stride_qz // stride_qm // stride_qk
stride_qh_2d = stride_qh // stride_qm // stride_qk
q_tile_ptr = tl.make_block_ptr(base=Q,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
k_tile_ptr = tl.make_block_ptr(base=K,
shape=(D0, BLOCK_DMODEL),
strides=(stride_kn, stride_kk),
offsets=(
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
v_tile_ptr = tl.make_block_ptr(base=V,
shape=(D0, BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
do_tile_ptr = tl.make_block_ptr(base=DO,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
dq_tile_ptr = tl.make_block_ptr(base=DQ,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
dk_tile_ptr = tl.make_block_ptr(base=DK,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
dv_tile_ptr = tl.make_block_ptr(base=DV,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
q_tile_ptr = tl.make_block_ptr(
base=Q,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
k_tile_ptr = tl.make_block_ptr(
base=K,
shape=(D0, BLOCK_DMODEL),
strides=(stride_kn, stride_kk),
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
v_tile_ptr = tl.make_block_ptr(
base=V,
shape=(D0, BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
do_tile_ptr = tl.make_block_ptr(
base=DO,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
dq_tile_ptr = tl.make_block_ptr(
base=DQ,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
dk_tile_ptr = tl.make_block_ptr(
base=DK,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
dv_tile_ptr = tl.make_block_ptr(
base=DV,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
# offset pointers for batch/head
DQ += off_z * stride_qz + off_h * stride_qh
for start_n in range(0, num_block):
@@ -250,8 +252,7 @@ def _bwd_kernel(
# recompute p = softmax(qk, dim=-1).T
# NOTE: `do` is pre-divided by `l`; no normalization here
qk = tl.dot(q, tl.trans(k))
qk = tl.where(offs_m_curr[:, None] >= (
offs_n[None, :]), qk, float("-inf"))
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
m = tl.load(m_ptrs + offs_m_curr)
p = tl.exp(qk * sm_scale - m[:, None])
# compute dv
@@ -301,29 +302,21 @@ class _attention(torch.autograd.Function):
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
L = torch.empty(
(q.shape[0] * q.shape[1], q.shape[2]),
device=q.device,
dtype=torch.float32)
m = torch.empty(
(q.shape[0] * q.shape[1], q.shape[2]),
device=q.device,
dtype=torch.float32)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
D0 = q.shape[0] * q.shape[1] * q.shape[2]
_fwd_kernel[grid](
q, k, v, sm_scale,
L, m,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2], D0,
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk, num_warps=num_warps,
num_stages=2,
)
q, k, v, sm_scale, #
L, m, #
o, #
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
q.shape[0], q.shape[1], q.shape[2], D0, #
BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=Lk, #
num_warps=num_warps, num_stages=2)
ctx.save_for_backward(q, k, v, o, L, m)
ctx.grid = grid
@@ -343,25 +336,22 @@ class _attention(torch.autograd.Function):
delta = torch.empty_like(l)
D0 = q.shape[0] * q.shape[1] * q.shape[2]
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do, l,
do_scaled, delta,
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
l, m,
delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2], D0,
ctx.grid[0],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
num_stages=1,
)
o, do, l, #
do_scaled, delta, #
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL)
_bwd_kernel[(ctx.grid[1], )](
q, k, v, ctx.sm_scale, #
o, do_scaled, #
dq, dk, dv, #
l, m, #
delta, #
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
q.shape[0], q.shape[1], q.shape[2], D0, #
ctx.grid[0], #
BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=ctx.BLOCK_DMODEL, #
num_warps=8, num_stages=1)
return dq, dk, dv, None
@@ -380,15 +370,9 @@ attention = _attention.apply
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+")
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
q = torch.empty(
(Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(
mean=0.1, std=0.2).requires_grad_()
k = torch.empty(
(Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(
mean=0.4, std=0.2).requires_grad_()
v = torch.empty(
(Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(
mean=0.3, std=0.2).requires_grad_()
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
sm_scale = 0.2
dout = torch.randn_like(q)
# reference implementation
@@ -427,22 +411,25 @@ except BaseException:
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 14)],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
args={
'H': N_HEADS,
'BATCH': BATCH,
'D_HEAD': D_HEAD,
'dtype': torch.float16,
'mode': mode}
) for mode in ['fwd', 'bwd']]
configs = [
triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 14)],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
args={
'H': N_HEADS,
'BATCH': BATCH,
'D_HEAD': D_HEAD,
'dtype': torch.float16,
'mode': mode,
},
) for mode in ['fwd', 'bwd']
]
@triton.testing.perf_report(configs)
@@ -463,9 +450,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
if provider == "flash":
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros(
(BATCH + 1,), device=device, dtype=torch.int32)
lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32)
cu_seqlens[1:] = lengths.cumsum(0)
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)

View File

@@ -32,19 +32,30 @@ import triton.language as tl
@triton.jit
def matmul_no_scf_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr
):
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
def matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr #
):
a_block_ptr = tl.make_block_ptr(
base=a_ptr,
shape=(M, K),
strides=(stride_am, stride_ak),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_K),
order=(1, 0),
)
b_block_ptr = tl.make_block_ptr(
base=b_ptr,
shape=(K, N),
strides=(stride_bk, stride_bn),
offsets=(0, 0),
block_shape=(BLOCK_K, BLOCK_N),
order=(0, 1),
)
a = tl.load(a_block_ptr)
b = tl.load(b_block_ptr)
@@ -54,8 +65,8 @@ def matmul_no_scf_kernel(
c = c.to(tl.float16)
if USE_TMA_EPILOGUE:
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
tl.store(c_block_ptr, c)
else:
offs_m = tl.arange(0, BLOCK_M)
@@ -64,33 +75,30 @@ def matmul_no_scf_kernel(
tl.store(c_ptrs, c)
@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,ENABLE_WS',
itertools.chain(
*[
[
# numCTAs = 1, no TMA multicast:
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
# static mask, cluster 4x1
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
# dynamic mask, cluster 2x2
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
# small M, N
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
] for USE_TMA_EPILOGUE in [True, False]
for ENABLE_WS in [False, True]
]))
@pytest.mark.parametrize(
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,ENABLE_WS',
itertools.chain(*[[
# numCTAs = 1, no TMA multicast:
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
# static mask, cluster 4x1
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
# dynamic mask, cluster 2x2
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
# small M, N
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
] for USE_TMA_EPILOGUE in [True, False] for ENABLE_WS in [False, True]]))
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, ENABLE_WS):
if (TRANS_A):
@@ -107,46 +115,41 @@ def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE
else:
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
matmul_no_scf_kernel[(1, 1)](a_ptr=a, b_ptr=b, c_ptr=c,
M=M, N=N, K=K,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1),
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,
num_warps=NUM_WARPS,
num_ctas=NUM_CTAS,
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"),
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE,
enable_warp_specialization=ENABLE_WS)
matmul_no_scf_kernel[(1, 1)](
a_ptr=a, b_ptr=b, c_ptr=c, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_cm=c.stride(0), stride_cn=c.stride(1), #
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, #
num_warps=NUM_WARPS, #
num_ctas=NUM_CTAS, #
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), #
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE, #
enable_warp_specialization=ENABLE_WS)
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
golden = torch.matmul(a_f32, b_f32)
torch.set_printoptions(profile="full")
assert_close(
c,
golden,
rtol=1e-2,
atol=1e-3,
check_dtype=False)
assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_wm, stride_wn,
stride_zm, stride_zn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr,
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr,
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr,
W_ORDER_0: tl.constexpr, W_ORDER_1: tl.constexpr,
Z_ORDER_0: tl.constexpr, Z_ORDER_1: tl.constexpr
):
def matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_wm, stride_wn, #
stride_zm, stride_zn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, #
out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr, #
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, #
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, #
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, #
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, #
W_ORDER_0: tl.constexpr, W_ORDER_1: tl.constexpr, #
Z_ORDER_0: tl.constexpr, Z_ORDER_1: tl.constexpr #
):
pid = tl.program_id(axis=0)
num_pid_n = tl.cdiv(N, BLOCK_N)
num_pid_m = tl.cdiv(M, BLOCK_M)
@@ -159,13 +162,31 @@ def matmul_kernel(
block_offset_m = pid_m * BLOCK_M
block_offset_n = pid_n * BLOCK_N
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(A_ORDER_0, A_ORDER_1))
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(B_ORDER_0, B_ORDER_1))
a_tile_ptr = tl.make_block_ptr(
base=a_ptr,
shape=(M, K),
strides=(stride_am, stride_ak),
offsets=(block_offset_m, 0),
block_shape=(BLOCK_M, BLOCK_K),
order=(A_ORDER_0, A_ORDER_1),
)
b_tile_ptr = tl.make_block_ptr(
base=b_ptr,
shape=(K, N),
strides=(stride_bk, stride_bn),
offsets=(0, block_offset_n),
block_shape=(BLOCK_K, BLOCK_N),
order=(B_ORDER_0, B_ORDER_1),
)
# for chain-dot, BLOCK_N must always be equal to N, and each program loads the whole W matrix
w_tile_ptr = tl.make_block_ptr(base=w_ptr, shape=(N, N), strides=(stride_wm, stride_wn),
offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_N), order=(W_ORDER_0, W_ORDER_1))
w_tile_ptr = tl.make_block_ptr(
base=w_ptr,
shape=(N, N),
strides=(stride_wm, stride_wn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_N),
order=(W_ORDER_0, W_ORDER_1),
)
z = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
offs_m = block_offset_m + tl.arange(0, BLOCK_M)
@@ -204,141 +225,151 @@ def matmul_kernel(
if USE_TMA_STORE:
z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N), strides=(stride_zm, stride_zn),
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(Z_ORDER_0, Z_ORDER_1))
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N),
order=(Z_ORDER_0, Z_ORDER_1))
tl.store(z_block_ptr, z, boundary_check=(0, 1))
else:
tl.store(z_ptrs, z, mask=mask)
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
for enable_ws in [False, True]
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
]
for epilogue in ['softmax']
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
for trans_a in [False,]
for trans_b in [True,]
for trans_output in [False,]
for num_stages in [3]
for enable_ws in [False, True]
] + [
# loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
]
for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot']
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
for trans_a in [False,]
for trans_b in [True,]
for trans_output in [False,]
for num_stages in [3]
for enable_ws in [False, True]
if not (epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
]
for out_dtype in ['float32',]
for use_tma_store in [False,]
for trans_a in [False, True]
for trans_b in [False, True]
for trans_output in [False, True]
for num_stages in [3]
for enable_ws in [False, True]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for n in [16, 32, 64, 128, 256]
for trans_output in [False,]
for out_dtype in ['float32',]
for use_tma_store in [False,]
for num_stages in [2, 4, 5, 7]
for enable_ws in [False, True]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
]
for shape in [
[512, 360, 1024],
[360, 4096, 512],
]
for trans_output in [False,]
for out_dtype in ['float32',]
for use_tma_store in [False, True]
for num_stages in [3, 4]
for enable_ws in [False, True]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()
[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
for enable_ws in [False, True]
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
]
for epilogue in ['softmax']
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
for trans_a in [False]
for trans_b in [True]
for trans_output in [False]
for num_stages in [3]
for enable_ws in [False, True]
] + [
# loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64]
for num_warps in [4, 8]
for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256]
for num_warps in [4, 8]
for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
]
for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot']
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
for trans_a in [False]
for trans_b in [True]
for trans_output in [False]
for num_stages in [3]
for enable_ws in [False, True]
if not (epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
]
for out_dtype in ['float32']
for use_tma_store in [False]
for trans_a in [False, True]
for trans_b in [False, True]
for trans_output in [False, True]
for num_stages in [3]
for enable_ws in [False, True]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages,
enable_ws)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
for enable_ws in [False, True]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
]
for shape in [
[512, 360, 1024],
[360, 4096, 512],
]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False, True]
for num_stages in [3, 4]
for enable_ws in [False, True]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower()
if NUM_CTAS > 1 and enable_tma in ["on", "true", "1"]:
pytest.skip('multi-CTA with TMA not supported in MaterializeLoadStore')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
@@ -410,38 +441,38 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A,
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
pgm = matmul_kernel[grid](a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z,
M=M, N=N, K=K,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_wm=w.stride(0), stride_wn=w.stride(1),
stride_zm=z.stride(0), stride_zn=z.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8,
out_dtype=out_dtype,
USE_TMA_STORE=USE_TMA_STORE,
ADD_MATRIX=epilogue == 'add-matrix',
ADD_ROWS=epilogue == 'add-rows',
ADD_COLS=epilogue == 'add-cols',
DO_SOFTMAX=epilogue == 'softmax',
CHAIN_DOT=epilogue == 'chain-dot',
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1],
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1],
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1],
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1],
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES,
enable_warp_specialization=ENABLE_WS)
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, #
enable_warp_specialization=ENABLE_WS)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden,
rtol=1e-2,
atol=1e-3,
check_dtype=False)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:

View File

@@ -27,16 +27,20 @@ import triton.language as tl
@triton.jit
def gemm_fusion_kernel(A, B, C, E,
M, N, K,
stride_am, stride_ak, stride_bn, stride_bk, stride_cn, stride_ck, stride_em, stride_ek,
def gemm_fusion_kernel(A, B, C, E, #
M, N, K, #
stride_am, stride_ak, stride_bn, stride_bk, stride_cn, stride_ck, stride_em, stride_ek, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
pid = tl.program_id(0)
a_tile_ptr = tl.make_block_ptr(base=A, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_tile_ptr = tl.make_block_ptr(base=B, shape=(N, K), strides=(stride_bn, stride_bk), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))
c_tile_ptr = tl.make_block_ptr(base=C, shape=(N, K), strides=(stride_cn, stride_ck), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))
e_tile_ptr = tl.make_block_ptr(base=E, shape=(M, K), strides=(stride_em, stride_ek), offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
a_tile_ptr = tl.make_block_ptr(base=A, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_tile_ptr = tl.make_block_ptr(base=B, shape=(N, K), strides=(stride_bn, stride_bk), offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))
c_tile_ptr = tl.make_block_ptr(base=C, shape=(N, K), strides=(stride_cn, stride_ck), offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))
e_tile_ptr = tl.make_block_ptr(base=E, shape=(M, K), strides=(stride_em, stride_ek), offsets=(pid * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
acc_e = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
a = tl.load(a_tile_ptr)
@@ -57,66 +61,70 @@ def gemm_fusion_kernel(A, B, C, E,
def test_gemm_fusion():
M, N, K = 4096, 4096, 64
BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64
A = torch.empty(
(M, K), dtype=torch.float16, device='cuda').normal_(
mean=0.1, std=0.2)
B = torch.empty(
(N, K), dtype=torch.float16, device='cuda').normal_(
mean=0.1, std=0.2)
C = torch.empty(
(N, K), dtype=torch.float16, device='cuda').normal_(
mean=0.1, std=0.2)
A = torch.empty((M, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)
B = torch.empty((N, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)
C = torch.empty((N, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)
E = torch.empty((M, K), dtype=torch.float16, device='cuda')
ref_out = torch.matmul(torch.matmul(A, B.T), C)
num_warps = 4
grid = (triton.cdiv(M, BLOCK_M), 1)
gemm_fusion_kernel[grid](A, B, C, E, M, N, K,
A.stride(0), A.stride(1), B.stride(0), B.stride(
1), C.stride(0), C.stride(1), E.stride(0), E.stride(1),
BLOCK_M, BLOCK_N, BLOCK_K, num_warps=num_warps)
gemm_fusion_kernel[grid](
A, B, C, E, M, N, K, #
A.stride(0), A.stride(1), #
B.stride(0), B.stride(1), #
C.stride(0), C.stride(1), #
E.stride(0), E.stride(1), #
BLOCK_M, BLOCK_N, BLOCK_K, #
num_warps=num_warps)
torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0)
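For reference, the block-pointer calls being re-wrapped above all follow the same pattern: describe the full tensor (shape, strides, order), then carve out one tile via offsets and block_shape. A minimal self-contained sketch; the kernel name and sizes here are illustrative only and not part of this change:
import torch
import triton
import triton.language as tl
@triton.jit
def copy_tile_kernel(src_ptr, dst_ptr, M, N, stride_m, stride_n,  #
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid = tl.program_id(0)
    # Each program copies one BLOCK_M x BLOCK_N tile through block pointers.
    src = tl.make_block_ptr(base=src_ptr, shape=(M, N), strides=(stride_m, stride_n),
                            offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    dst = tl.make_block_ptr(base=dst_ptr, shape=(M, N), strides=(stride_m, stride_n),
                            offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    tile = tl.load(src, boundary_check=(0, 1))
    tl.store(dst, tile, boundary_check=(0, 1))
x = torch.randn((256, 128), device='cuda', dtype=torch.float16)
y = torch.empty_like(x)
copy_tile_kernel[(triton.cdiv(256, 64), )](x, y, 256, 128, x.stride(0), x.stride(1), BLOCK_M=64, BLOCK_N=128)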
@triton.jit
def batched_gemm_fusion(
Q, K, V, Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
Z, NH, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
def batched_gemm_fusion(Q, K, V, Out, #
stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vk, stride_vn, #
stride_oz, stride_oh, stride_om, stride_on, #
Z, NH, N_CTX, #
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
BLOCK_N: tl.constexpr):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
q_tile_ptr = tl.make_block_ptr(base=Q,
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
strides=(stride_qz, stride_qh, stride_qm, stride_qk),
offsets=(off_hz // NH, off_hz % NH, start_m, 0),
block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),
order=(3, 2, 1, 0))
k_tile_ptr = tl.make_block_ptr(base=K,
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
strides=(stride_kz, stride_kh, stride_kn, stride_kk),
offsets=(off_hz // NH, off_hz % NH, 0, 0),
block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),
order=(3, 2, 1, 0))
v_tile_ptr = tl.make_block_ptr(base=V,
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
strides=(stride_vz, stride_vh, stride_vk, stride_vn),
offsets=(off_hz // NH, off_hz % NH, 0, 0),
block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),
order=(3, 2, 1, 0))
o_tile_ptr = tl.make_block_ptr(base=Out,
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
strides=(stride_oz, stride_oh, stride_om, stride_on),
offsets=(off_hz // NH, off_hz % NH, start_m, 0),
block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),
order=(3, 2, 1, 0))
q_tile_ptr = tl.make_block_ptr(
base=Q,
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
strides=(stride_qz, stride_qh, stride_qm, stride_qk),
offsets=(off_hz // NH, off_hz % NH, start_m, 0),
block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),
order=(3, 2, 1, 0),
)
k_tile_ptr = tl.make_block_ptr(
base=K,
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
strides=(stride_kz, stride_kh, stride_kn, stride_kk),
offsets=(off_hz // NH, off_hz % NH, 0, 0),
block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),
order=(3, 2, 1, 0),
)
v_tile_ptr = tl.make_block_ptr(
base=V,
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
strides=(stride_vz, stride_vh, stride_vk, stride_vn),
offsets=(off_hz // NH, off_hz % NH, 0, 0),
block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),
order=(3, 2, 1, 0),
)
o_tile_ptr = tl.make_block_ptr(
base=Out,
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
strides=(stride_oz, stride_oh, stride_om, stride_on),
offsets=(off_hz // NH, off_hz % NH, start_m, 0),
block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),
order=(3, 2, 1, 0),
)
q = tl.load(q_tile_ptr, boundary_check=(0, 1, 2, 3))
q = tl.view(q, (BLOCK_M, BLOCK_DMODEL))
@@ -155,12 +163,13 @@ def test_batched_gemm_fusion():
ref_out = torch.matmul(torch.matmul(A, BT), C)
num_warps = 4
grid = (triton.cdiv(N_CTX, BLOCK_M), B * NH)
batched_gemm_fusion[grid](A, B, C, E,
A.stride(0), A.stride(1), A.stride(2), A.stride(3),
B.stride(0), B.stride(1), B.stride(2), B.stride(3),
C.stride(0), C.stride(1), C.stride(2), C.stride(3),
E.stride(0), E.stride(1), E.stride(2), E.stride(3),
Z, NH, N_CTX,
BLOCK_M, BLOCK_DMODEL, BLOCK_N, num_warps=num_warps)
batched_gemm_fusion[grid](
A, B, C, E, #
A.stride(0), A.stride(1), A.stride(2), A.stride(3), #
B.stride(0), B.stride(1), B.stride(2), B.stride(3), #
C.stride(0), C.stride(1), C.stride(2), C.stride(3), #
E.stride(0), E.stride(1), E.stride(2), E.stride(3), #
Z, NH, N_CTX, #
BLOCK_M, BLOCK_DMODEL, BLOCK_N, num_warps=num_warps)
torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0)

View File

@@ -24,10 +24,8 @@ def add_kernel(
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x_block_ptr = tl.make_block_ptr(
base=x_ptr, shape=(n_elements, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
block_shape=(BLOCK_SIZE, ), order=(0, )
)
x_block_ptr = tl.make_block_ptr(base=x_ptr, shape=(n_elements, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
block_shape=(BLOCK_SIZE, ), order=(0, ))
x = tl.load(x_block_ptr, boundary_check=(0, ), padding_option='zero')
y = tl.load(y_ptr + offsets, mask=mask)
@@ -36,9 +34,7 @@ def add_kernel(
@pytest.mark.parametrize('SIZE,BLOCK_SIZE,dtype_str',
[(98432, 1024, dtype_str)
for dtype_str in ['float16', 'float32']
])
[(98432, 1024, dtype_str) for dtype_str in ['float16', 'float32']])
def test_add(SIZE, BLOCK_SIZE, dtype_str):
dtype = dtype_mapping[dtype_str]
output = torch.empty(SIZE, device='cuda', dtype=dtype)
@@ -46,7 +42,8 @@ def test_add(SIZE, BLOCK_SIZE, dtype_str):
y = torch.randn(SIZE, device='cuda', dtype=dtype)
def grid(meta):
return (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
return (triton.cdiv(SIZE, meta['BLOCK_SIZE']), )
add_kernel[grid](x, y, output, SIZE, BLOCK_SIZE=BLOCK_SIZE)
output_torch = x + y
@@ -64,25 +61,20 @@ def load_reduce_kernel(
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
x_ptr = tl.make_block_ptr(
base=x_ptr, shape=(BLOCK_M, BLOCK_N), strides=(stride_xm, stride_xn),
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)
)
x_ptr = tl.make_block_ptr(base=x_ptr, shape=(BLOCK_M, BLOCK_N), strides=(stride_xm, stride_xn), offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
x = tl.load(x_ptr)
y = tl.max(x, axis=1)
tl.store(y_ptr + tl.arange(0, BLOCK_M), y)
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,dtype_str',
[(128, 64, dtype_str)
for dtype_str in ['float16']
])
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,dtype_str', [(128, 64, dtype_str) for dtype_str in ['float16']])
def test_load_reduce(BLOCK_M, BLOCK_N, dtype_str):
dtype = dtype_mapping[dtype_str]
x = torch.randn((BLOCK_M, BLOCK_N), device='cuda', dtype=dtype)
y = torch.empty((BLOCK_M, ), device='cuda', dtype=dtype)
load_reduce_kernel[(1,)](x, y, x.stride(0), x.stride(1), y.stride(0), BLOCK_M, BLOCK_N)
load_reduce_kernel[(1, )](x, y, x.stride(0), x.stride(1), y.stride(0), BLOCK_M, BLOCK_N)
golden = x.max(dim=1)[0]
torch.set_printoptions(profile='full')

View File

@@ -18,7 +18,6 @@
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
Fused Attention
===============
@@ -40,18 +39,17 @@ import triton.language as tl
key=['Q', 'K', 'V'],
)
@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale,
L, M,
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
def _fwd_kernel(Q, K, V, sm_scale, #
L, M, #
Out, #
stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vk, stride_vn, #
stride_oz, stride_oh, stride_om, stride_on, #
Z, H, N_CTX, #
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
BLOCK_N: tl.constexpr #
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
# initialize offsets
@@ -116,11 +114,10 @@ def _fwd_kernel(
@triton.jit
def _bwd_preprocess(
Out, DO, L,
NewDO, Delta,
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
def _bwd_preprocess(Out, DO, L, #
NewDO, Delta, #
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr #
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
# load
@@ -136,19 +133,18 @@ def _bwd_preprocess(
@triton.jit
def _bwd_kernel(
Q, K, V, sm_scale, Out, DO,
DQ, DK, DV,
L, M,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
def _bwd_kernel(Q, K, V, sm_scale, Out, DO, #
DQ, DK, DV, #
L, M, #
D, #
stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vk, stride_vn, #
Z, H, N_CTX, #
num_block, #
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
BLOCK_N: tl.constexpr, #
):
off_hz = tl.program_id(0)
off_z = off_hz // H
off_h = off_hz % H
@@ -240,16 +236,16 @@ class _attention(torch.autograd.Function):
assert num_warps == 4
_fwd_kernel[grid](
q, k, v, sm_scale,
L, m,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk,
q, k, v, sm_scale, #
L, m, #
o, #
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
q.shape[0], q.shape[1], q.shape[2], #
BLOCK_M=BLOCK, BLOCK_N=BLOCK, #
BLOCK_DMODEL=Lk #
)
ctx.save_for_backward(q, k, v, o, L, m)
@@ -269,24 +265,23 @@ class _attention(torch.autograd.Function):
do_scaled = torch.empty_like(do)
delta = torch.empty_like(l)
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do, l,
do_scaled, delta,
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
l, m,
delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
ctx.grid[0],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
num_stages=1,
o, do, l, #
do_scaled, delta, #
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL)
_bwd_kernel[(ctx.grid[1], )](
q, k, v, ctx.sm_scale, #
o, do_scaled, #
dq, dk, dv, #
l, m, #
delta, #
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
q.shape[0], q.shape[1], q.shape[2], #
ctx.grid[0], #
BLOCK_M=BLOCK, BLOCK_N=BLOCK, #
BLOCK_DMODEL=ctx.BLOCK_DMODEL, #
num_warps=8, num_stages=1 #
)
return dq, dk, dv, None
@@ -339,19 +334,19 @@ except BaseException:
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [triton.testing.Benchmark(
x_names=['N_CTX'],
# x_vals=[2**i for i in range(10, 14)],
x_vals=[2**i for i in range(10, 11)],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
# ) for mode in ['fwd', 'bwd']]
) for mode in ['fwd']]
configs = [
triton.testing.Benchmark(
x_names=['N_CTX'],
# x_vals=[2**i for i in range(10, 14)],
x_vals=[2**i
for i in range(10, 11)], line_arg='provider', line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), styles=[('red', '-'), ('blue', '-')], ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
# ) for mode in ['fwd', 'bwd']]
)
for mode in ['fwd']
]
@triton.testing.perf_report(configs)
@@ -374,9 +369,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
if provider == "flash":
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros(
(BATCH + 1,), device=device, dtype=torch.int32)
lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32)
cu_seqlens[1:] = lengths.cumsum(0)
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
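The benchmark harness above boils down to timing a closure with triton.testing.do_bench; a minimal sketch with a stand-in workload rather than the attention kernel:
import torch
import triton
x = torch.randn(8192, device='cuda', dtype=torch.float16)
fn = lambda: x * 2 + 1  # stand-in workload, not the fused-attention call
ms = triton.testing.do_bench(fn, warmup=25, rep=100)
print(f"{ms:.3f} ms per call")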

View File

@@ -29,14 +29,14 @@ import triton.language as tl
@triton.jit
def static_persistent_matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
NUM_SM: tl.constexpr,
def static_persistent_matmul_kernel( #
a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
NUM_SM: tl.constexpr #
):
start_tile = tl.program_id(axis=0)
m_tiles = tl.cdiv(M, BLOCK_M)
@@ -68,14 +68,14 @@ def static_persistent_matmul_kernel(
@triton.jit
def static_persistent_tma_matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
NUM_SM: tl.constexpr,
def static_persistent_tma_matmul_kernel( #
a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
NUM_SM: tl.constexpr #
):
start_tile = tl.program_id(axis=0)
m_tiles = tl.cdiv(M, BLOCK_M)
@@ -88,8 +88,10 @@ def static_persistent_tma_matmul_kernel(
block_offset_m = pre_pid_m * BLOCK_M
block_offset_n = pre_pid_n * BLOCK_N
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
for tile_id in range(start_tile, num_tiles, NUM_SM):
pid_m = tile_id // n_tiles
pid_n = tile_id % n_tiles
@@ -114,21 +116,23 @@ def static_persistent_tma_matmul_kernel(
pre_pid_n = pid_n
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA',
[(*shape, use_tma)
for shape in [
[4096, 4096, 64, 64, 64, 16, 4, 1, False, True],
[4096, 4096, 64, 64, 64, 32, 4, 1, False, True],
[4096, 4096, 64, 256, 64, 16, 4, 1, False, True],
[4096, 4096, 64, 128, 128, 16, 4, 1, False, True],
# TODO: fix issue for 8-warp persistent kernel
# [4096, 4096, 64, 128, 128, 16, 8, 1, False, True],
# [4096, 4096, 64, 128, 256, 16, 8, 1, False, True],
]
for use_tma in [False, True]
])
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA',
[(*shape, use_tma) for shape in [
[4096, 4096, 64, 64, 64, 16, 4, 1, False, True],
[4096, 4096, 64, 64, 64, 32, 4, 1, False, True],
[4096, 4096, 64, 256, 64, 16, 4, 1, False, True],
[4096, 4096, 64, 128, 128, 16, 4, 1, False, True],
# TODO: fix issue for 8-warp persistent kernel
# [4096, 4096, 64, 128, 128, 16, 8, 1, False, True],
# [4096, 4096, 64, 128, 256, 16, 8, 1, False, True],
] for use_tma in [False, True]])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA):
def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS,
TRANS_A, TRANS_B, USE_TMA):
if (TRANS_A):
a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T
else:
@@ -141,25 +145,33 @@ def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLO
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),)
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )
if USE_TMA:
static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS)
static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0),
stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS,
num_ctas=NUM_CTAS)
else:
static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS)
static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0),
stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS,
num_ctas=NUM_CTAS)
th_c = torch.matmul(a, b)
torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False)
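The persistent variants above launch at most one program per SM and let each program iterate over output tiles with a stride of NUM_SM. The grid computation in isolation (the matrix and tile sizes here are only an example):
import torch
import triton
M, N = 4096, 4096
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
# Capped at the SM count; inside the kernel each program then runs
# `for tile_id in range(start_tile, num_tiles, NUM_SM)` to cover the remaining tiles.
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )
print(grid({'NUM_SM': num_SMs, 'BLOCK_M': 128, 'BLOCK_N': 128}))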
@triton.jit
def warp_specialized_matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
def warp_specialized_matmul_kernel( #
a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
):
tid = tl.program_id(axis=0)
n_tiles = tl.cdiv(N, BLOCK_N)
@@ -193,13 +205,13 @@ def warp_specialized_matmul_kernel(
@triton.jit
def tma_warp_specialized_matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
def tma_warp_specialized_matmul_kernel( #
a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
):
tid = tl.program_id(axis=0)
n_tiles = tl.cdiv(N, BLOCK_N)
@@ -232,8 +244,7 @@ def tma_warp_specialized_matmul_kernel(
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA',
[(*shape, use_tma)
for shape in [
[(*shape, use_tma) for shape in [
[2048, 2048, 64, 64, 64, 16, 1, False, True],
[4096, 4096, 64, 64, 64, 16, 1, False, True],
[128, 4096, 64, 64, 64, 16, 1, False, True],
@@ -257,9 +268,7 @@ def tma_warp_specialized_matmul_kernel(
[4096, 4096, 128, 256, 128, 64, 4, False, True],
[4096, 4096, 256, 128, 256, 64, 4, False, True],
[4096, 4096, 256, 256, 256, 64, 4, False, True],
]
for use_tma in [False, True]
])
] for use_tma in [False, True]])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA):
if (TRANS_A):
@@ -274,29 +283,29 @@ def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
grid = lambda META: (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),)
grid = lambda META: (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), )
if USE_TMA:
tma_warp_specialized_matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_M, BLOCK_N, BLOCK_K,
num_warps=4,
num_ctas=NUM_CTAS,
a, b, c, #
M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
BLOCK_M, BLOCK_N, BLOCK_K, #
num_warps=4, #
num_ctas=NUM_CTAS, #
enable_warp_specialization=True)
else:
warp_specialized_matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_M, BLOCK_N, BLOCK_K,
num_warps=4,
num_ctas=NUM_CTAS,
a, b, c, #
M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
BLOCK_M, BLOCK_N, BLOCK_K, #
num_warps=4, #
num_ctas=NUM_CTAS, #
enable_warp_specialization=True)
th_c = torch.matmul(a, b)
@@ -304,14 +313,14 @@ def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K
@triton.jit
def static_persistent_warp_specialized_matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
NUM_SM: tl.constexpr,
def static_persistent_warp_specialized_matmul_kernel( #
a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
NUM_SM: tl.constexpr #
):
start_tile = tl.program_id(axis=0)
m_tiles = tl.cdiv(M, BLOCK_M)
@@ -343,14 +352,14 @@ def static_persistent_warp_specialized_matmul_kernel(
@triton.jit
def static_persistent_tma_warp_specialized_matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
NUM_SM: tl.constexpr,
def static_persistent_tma_warp_specialized_matmul_kernel( #
a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
NUM_SM: tl.constexpr #
):
start_tile = tl.program_id(axis=0)
m_tiles = tl.cdiv(M, BLOCK_M)
@@ -363,8 +372,10 @@ def static_persistent_tma_warp_specialized_matmul_kernel(
block_offset_m = pre_pid_m * BLOCK_M
block_offset_n = pre_pid_n * BLOCK_N
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
for tile_id in range(start_tile, num_tiles, NUM_SM):
pid_m = tile_id // n_tiles
pid_n = tile_id % n_tiles
@@ -390,8 +401,7 @@ def static_persistent_tma_warp_specialized_matmul_kernel(
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA',
[(*shape, use_tma)
for shape in [
[(*shape, use_tma) for shape in [
[2048, 2048, 64, 64, 64, 16, 1, False, True],
[4096, 4096, 64, 64, 64, 16, 1, False, True],
[128, 4096, 64, 64, 64, 16, 1, False, True],
@@ -415,11 +425,10 @@ def static_persistent_tma_warp_specialized_matmul_kernel(
[4096, 4096, 128, 256, 128, 64, 4, False, True],
[4096, 4096, 256, 128, 256, 64, 4, False, True],
[4096, 4096, 256, 256, 256, 64, 4, False, True],
]
for use_tma in [False, True]
])
] for use_tma in [False, True]])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA):
def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B,
USE_TMA):
if (TRANS_A):
a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T
else:
@@ -432,27 +441,22 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),)
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )
if USE_TMA:
static_persistent_tma_warp_specialized_matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_M, BLOCK_N, BLOCK_K, num_SMs,
num_warps=4, num_ctas=NUM_CTAS,
a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M,
BLOCK_N, BLOCK_K, num_SMs, num_warps=4, num_ctas=NUM_CTAS, #
enable_warp_specialization=True)
else:
static_persistent_warp_specialized_matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_M, BLOCK_N, BLOCK_K, num_SMs,
num_warps=4, num_ctas=NUM_CTAS,
a, b, c, #
M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
BLOCK_M, BLOCK_N, BLOCK_K, num_SMs, #
num_warps=4, num_ctas=NUM_CTAS, #
enable_warp_specialization=True)
th_c = torch.matmul(a, b)
@@ -460,16 +464,15 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N
@triton.jit
def static_persistent_matmul_no_scf_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr,
NUM_SM: tl.constexpr, USE_TMA_LOAD: tl.constexpr,
):
def static_persistent_matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr, #
NUM_SM: tl.constexpr, USE_TMA_LOAD: tl.constexpr #
):
start_tile = tl.program_id(axis=0)
m_tiles = tl.cdiv(M, BLOCK_M)
n_tiles = tl.cdiv(N, BLOCK_N)
@@ -487,7 +490,8 @@ def static_persistent_matmul_no_scf_kernel(
offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
if USE_TMA_EPILOGUE:
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0))
for tile_id in range(start_tile, num_tiles, NUM_SM):
pid_m = tile_id // n_tiles
@@ -524,29 +528,27 @@ def static_persistent_matmul_no_scf_kernel(
pre_pid_n = pid_n
@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,USE_TMA_LOAD',
itertools.chain(
*[
[
# numCTAs = 1, no TMA multicast:
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
# small M, N
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
] for USE_TMA_EPILOGUE in [True, False]
for USE_TMA_LOAD in [True, False]
]))
@pytest.mark.parametrize(
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,USE_TMA_LOAD',
itertools.chain(*[[
# numCTAs = 1, no TMA multicast:
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
# small M, N
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
] for USE_TMA_EPILOGUE in [True, False] for USE_TMA_LOAD in [True, False]]))
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, USE_TMA_LOAD):
def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE,
USE_TMA_EPILOGUE, USE_TMA_LOAD):
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
else:
@@ -564,46 +566,42 @@ def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TR
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
# TODO: setting `enable_warp_specialization=False` leads to a compilation error.
static_persistent_matmul_no_scf_kernel[(num_SMs,)](a_ptr=a, b_ptr=b, c_ptr=c,
M=M, N=N, K=K,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1),
BLOCK_M=M if M < 128 else M // 2, BLOCK_N=N if N < 128 else N // 2, BLOCK_K=K, NUM_SM=num_SMs,
num_warps=NUM_WARPS,
num_ctas=NUM_CTAS,
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"),
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE,
USE_TMA_LOAD=USE_TMA_LOAD,
enable_warp_specialization=True)
static_persistent_matmul_no_scf_kernel[(num_SMs, )](
a_ptr=a, b_ptr=b, c_ptr=c, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_cm=c.stride(0), stride_cn=c.stride(1), #
BLOCK_M=M if M < 128 else M // 2, BLOCK_N=N if N < 128 else N // 2, BLOCK_K=K, NUM_SM=num_SMs, #
num_warps=NUM_WARPS, #
num_ctas=NUM_CTAS, #
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), #
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE, #
USE_TMA_LOAD=USE_TMA_LOAD, #
enable_warp_specialization=True)
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
golden = torch.matmul(a_f32, b_f32)
torch.set_printoptions(profile="full")
assert_close(
c,
golden,
rtol=1e-2,
atol=1e-3,
check_dtype=False)
assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
@triton.jit
def full_static_persistent_matmul_kernel(
a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_wm, stride_wn,
stride_zm, stride_zn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr,
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr,
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr,
NUM_SM: tl.constexpr
):
def full_static_persistent_matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_wm, stride_wn, #
stride_zm, stride_zn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, #
out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr, #
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, #
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, #
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, #
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, #
NUM_SM: tl.constexpr #
):
start_pid = tl.program_id(axis=0)
num_pid_n = tl.cdiv(N, BLOCK_N)
num_pid_m = tl.cdiv(M, BLOCK_M)
@@ -618,15 +616,18 @@ def full_static_persistent_matmul_kernel(
pre_block_offset_m = pre_pid_m * BLOCK_M
pre_block_offset_n = pre_pid_n * BLOCK_N
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
offsets=(pre_block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(A_ORDER_0, A_ORDER_1))
offsets=(pre_block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K),
order=(A_ORDER_0, A_ORDER_1))
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, pre_block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(B_ORDER_0, B_ORDER_1))
offsets=(0, pre_block_offset_n), block_shape=(BLOCK_K, BLOCK_N),
order=(B_ORDER_0, B_ORDER_1))
w_tile_ptr = tl.make_block_ptr(base=w_ptr, shape=(N, N), strides=(stride_wm, stride_wn),
offsets=(0, pre_block_offset_n), block_shape=(BLOCK_N, BLOCK_N), order=(0, 1))
if USE_TMA_STORE:
z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N), strides=(stride_zm, stride_zn),
offsets=(pre_block_offset_m, pre_block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
offsets=(pre_block_offset_m, pre_block_offset_n),
block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
for tile_id in range(start_pid, num_tiles, NUM_SM):
group_id = tile_id // num_pid_in_group
@@ -694,136 +695,120 @@ def full_static_persistent_matmul_kernel(
pre_pid_n = pid_n
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws)
for shape_w_c in [
[4096, 1, 1024, False, False],
[2048, 204, 1000, True, False],
[16, 524288, 32, False, True],
]
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
for enable_ws in [True]
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
# softmax works for one CTA
for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
# TODO: enable when num_warps != 4 is supported.
# [64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
]
for epilogue in ['softmax']
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
for trans_a in [False,]
for trans_b in [True,]
for num_stages in [3]
for enable_ws in [True]
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
]
for out_dtype in ['float32',]
for use_tma_store in [False,]
for trans_a in [False, True]
for trans_b in [False, True]
for num_stages in [3]
for enable_ws in [True]
] + [
# loop over epilogues besides of softmax
(*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4] for num_ctas in [1, 2]],
# # TODO: enable when num_warps != 4 is supported.
# # repeat
# # [64, 64, 32, 8, 1, 128, 256, 64],
# # [64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 1, 513, 193, 192],
]
for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot']
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
for trans_a in [False,]
for trans_b in [True,]
for num_stages in [3]
for enable_ws in [True]
if not (epilogue == 'chain-dot' and (shape_w_c[5] is not None or shape_w_c[0] != shape_w_c[1]))
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for n in [16, 32, 64, 128, 256]
for out_dtype in ['float32']
for use_tma_store in [False,]
for num_stages in [2, 4, 5, 7]
for enable_ws in [True]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2]
]
for shape in [
[512, 360, 1024],
[360, 4096, 512],
]
for out_dtype in ['float32']
for use_tma_store in [False, True]
for num_stages in [3, 4]
for enable_ws in [True]
]
)
@pytest.mark.skipif(torch.cuda.get_device_capability()
[0] < 9, reason="Requires compute capability >= 9")
def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS])) in [
'128-128-128-4-1-256-256-192-none-float32-True-3-True',
]:
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws) for shape_w_c in [
[4096, 1, 1024, False, False],
[2048, 204, 1000, True, False],
[16, 524288, 32, False, True],
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for enable_ws in [True]
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
# softmax works for one CTA
for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
# TODO: enable when num_warps != 4 is supported.
# [64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
]
for epilogue in ['softmax']
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
for trans_a in [False]
for trans_b in [True]
for num_stages in [3]
for enable_ws in [True]
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, 'none', out_dtype, use_tma_store, num_stages, enable_ws) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for num_stages in [3] for enable_ws in [True]
] + [
# loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4] for num_ctas in [1, 2]],
# # TODO: enable when num_warps != 4 is supported.
# # repeat
# # [64, 64, 32, 8, 1, 128, 256, 64],
# # [64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 1, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'] for out_dtype in
['float16', 'float32'] for use_tma_store in [False, True] for trans_a in [False] for trans_b in [True] for
num_stages in [3] for enable_ws in [True] if not (epilogue == 'chain-dot' and
(shape_w_c[5] is not None or shape_w_c[0] != shape_w_c[1]))
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for n in [16, 32, 64, 128, 256]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
for enable_ws in [True]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [[128, 128, 64, 4, 1], [256, 128, 64, 4, 2], [128, 128, 128, 4, 2]]
for shape in [
[512, 360, 1024],
[360, 4096, 512],
]
for out_dtype in ['float32']
for use_tma_store in [False, True]
for num_stages in [3, 4]
for enable_ws in [True]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B,
epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
if '-'.join(
map(str, [
BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES,
ENABLE_WS
])) in [
'128-128-128-4-1-256-256-192-none-float32-True-3-True',
]:
pytest.skip('out of resource: shared memory, Required: 263168')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
@@ -893,37 +878,36 @@ def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WAR
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
def grid(META):
return (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),)
return (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )
full_static_persistent_matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z,
M=M, N=N, K=K,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_wm=w.stride(0), stride_wn=w.stride(1),
stride_zm=z.stride(0), stride_zn=z.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8,
out_dtype=out_dtype,
USE_TMA_STORE=USE_TMA_STORE,
ADD_MATRIX=epilogue == 'add-matrix',
ADD_ROWS=epilogue == 'add-rows',
ADD_COLS=epilogue == 'add-cols',
DO_SOFTMAX=epilogue == 'softmax',
CHAIN_DOT=epilogue == 'chain-dot',
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1],
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1],
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES,
enable_warp_specialization=ENABLE_WS,
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, #
enable_warp_specialization=ENABLE_WS, #
NUM_SM=num_SMs)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden,
rtol=1e-2,
atol=1e-3,
check_dtype=False)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)

View File

@@ -19,7 +19,6 @@
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import pytest
import torch
from torch.testing import assert_close
@@ -29,21 +28,21 @@ import triton.language as tl
@triton.jit
def matmul_tma_load_store(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
OUTPUT_F16: tl.constexpr
def matmul_tma_load_store( #
a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
OUTPUT_F16: tl.constexpr #
):
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0),
block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
a = tl.load(a_block_ptr)
b = tl.load(b_block_ptr)
@@ -78,15 +77,15 @@ def test_tma_load_store(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_F
if OUTPUT_F16:
c = torch.empty((M, N), device=a.device, dtype=torch.float16)
matmul_tma_load_store[(1, 1)](a_ptr=a, b_ptr=b, c_ptr=c,
M=M, N=N, K=K,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1),
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,
num_warps=NUM_WARPS,
num_ctas=NUM_CTAS,
OUTPUT_F16=OUTPUT_F16)
matmul_tma_load_store[(1, 1)](
a_ptr=a, b_ptr=b, c_ptr=c, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_cm=c.stride(0), stride_cn=c.stride(1), #
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, #
OUTPUT_F16=OUTPUT_F16)
golden = torch.matmul(a, b)
torch.set_printoptions(profile="full")
assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False)

View File

@@ -54,17 +54,13 @@ def test_tma_wgmma_64_64_16_f16(TTGIR, TRANS_A, TRANS_B):
ttgir_path = os.path.dirname(__file__) + "/" + TTGIR
kernel = triton.compile(ttgir_path)
kernel[(1, 1, 1)](a.data_ptr(), b.data_ptr(), c.data_ptr(),
SIZE_M, SIZE_N, SIZE_K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0))
kernel[(1, 1, 1)]( #
a.data_ptr(), b.data_ptr(), c.data_ptr(), #
SIZE_M, SIZE_N, SIZE_K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0))
golden = torch.matmul(a, b)
torch.set_printoptions(profile="full", sci_mode=False)
assert_close(
c,
golden,
rtol=1e-2,
atol=1e-3,
check_dtype=False)
assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False)

View File

@@ -15,9 +15,9 @@ def kernel_device_assert(X, Y, BLOCK: tl.constexpr):
@triton.jit
def kernel_device_assert_scalar(X, Y, BLOCK: tl.constexpr):
def kernel_assert_passes(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
# Trivial assert
# Trivial assert, should not be an error.
tl.device_assert(0 == 0, "x != 0")
tl.store(Y + tl.arange(0, BLOCK), x)
@@ -48,6 +48,7 @@ def test_assert(func: str):
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
if func == "device_assert":
kernel_device_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
kernel_device_assert_scalar[(1,)](x, y, num_warps=2, BLOCK=shape[0])
elif func == "no_debug":
@@ -55,8 +56,32 @@ def test_assert(func: str):
kernel_device_assert_no_debug[(1,)](x, y, num_warps=2, BLOCK=shape[0])
elif func == "assert":
kernel_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
kernel_device_assert[(1, )](x, y, BLOCK=shape[0])
if func == "device_assert_passes":
# Assert passes; no error.
kernel_assert_passes[(1, )](x, y, BLOCK=shape[0])
elif func == "no_debug":
# TRITON_DEBUG=1 can override the debug flag
kernel_device_assert_no_debug[(1, )](x, y, BLOCK=shape[0])
elif func == "assert":
kernel_assert[(1, )](x, y, BLOCK=shape[0])
elif func == "static_assert":
kernel_static_assert[(1,)](x, y, BLOCK=shape[0])
kernel_static_assert[(1, )](x, y, BLOCK=shape[0])
elif func == "double_assert":
# Launching a different kernel after the first one asserted used to
# segfault. What seems to have happened is:
# - The first kernel is enqueued but doesn't run yet.
# - We go to launch the second kernel. Because this is the first time
# we're running it, we have to load the kernel into the GPU.
# - Loading the kernel takes some time, during which the first launch
# completes.
# - Now the GPU is in an error state. We need to detect this inside
# the kernel-launch/loading code and bail out properly. If we don't,
# we segfault.
kernel_device_assert[(1, )](x, y, BLOCK=shape[0])
kernel_assert_passes[(1, )](x, y, BLOCK=shape[0])
assert_close(y, x)
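To make the race described in the comment above concrete, here is a minimal, self-contained sketch (not part of the diff; the kernel names, shapes, and debug flag are illustrative). Running it is expected to abort with a device-side assert rather than segfault, which is exactly what the test guards against.

import torch
import triton
import triton.language as tl

@triton.jit(debug=True)
def _asserting_kernel(X, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    # The input below is non-negative, so this assert fires for every element.
    tl.device_assert(x < 0, "x >= 0")

@triton.jit
def _freshly_loaded_kernel(X, BLOCK: tl.constexpr):
    tl.store(X + tl.arange(0, BLOCK), 1)

x = torch.arange(0, 128, dtype=torch.int32, device="cuda")
_asserting_kernel[(1, )](x, BLOCK=128)       # enqueued; the device assert fires asynchronously
_freshly_loaded_kernel[(1, )](x, BLOCK=128)  # first launch loads the module while the context
                                             # may already be in an error state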
@@ -116,11 +141,19 @@ def test_assert_nested(caller: str, callee: str):
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
if caller == "none":
<<<<<<< HEAD
kernel_device_assert_nested[(1,)](x, y, num_warps=2, BLOCK=shape[0], jit_debug=callee)
elif caller == "true":
kernel_device_assert_nested_true[(1,)](x, y, num_warps=2, BLOCK=shape[0], jit_debug=callee)
elif caller == "false":
kernel_device_assert_nested_false[(1,)](x, y, num_warps=2, BLOCK=shape[0], jit_debug=callee)
=======
kernel_device_assert_nested[(1, )](x, y, BLOCK=shape[0], jit_debug=callee)
elif caller == "true":
kernel_device_assert_nested_true[(1, )](x, y, BLOCK=shape[0], jit_debug=callee)
elif caller == "false":
kernel_device_assert_nested_false[(1, )](x, y, BLOCK=shape[0], jit_debug=callee)
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
assert_close(y, x)

View File

@@ -4,9 +4,7 @@ import pytest
def pytest_addoption(parser):
parser.addoption(
"--device", action="store", default='cuda'
)
parser.addoption("--device", action="store", default='cuda')
@pytest.fixture

View File

@@ -1,4 +1,5 @@
import sys
import uuid
import torch
from torch.testing import assert_close
@@ -10,21 +11,49 @@ import triton.language as tl
@triton.jit
def kernel_device_print(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.device_print("", x)
tl.device_print("x: ", x)
tl.store(Y + tl.arange(0, BLOCK), x)
@triton.jit
def kernel_print(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
print("", x)
# Triton should add a space after this prefix.
print("x:", x)
tl.store(Y + tl.arange(0, BLOCK), x)
@triton.jit
def kernel_static_print(X, Y, BLOCK: tl.constexpr):
def kernel_device_print_large(
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
x = tl.full([BLOCK_M, BLOCK_N], 1, tl.int32)
# Triton should change this prefix to "x: ".
tl.device_print("x ", x)
@triton.jit
def kernel_print_multiple_args(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.static_print(x)
y = tl.full((BLOCK, ), 1, tl.int32)
print("", x, y)
@triton.jit
def kernel_device_print_multiple_args(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
y = tl.full((BLOCK, ), 1, tl.int32)
tl.device_print("", x, y)
tl.store(Y + tl.arange(0, BLOCK), y)
@triton.jit
def kernel_static_print(X, Y, BLOCK: tl.constexpr, PLACEHOLDER: tl.constexpr):
# This function takes an extra value as a tl.constexpr so this kernel is not
# cached. This way the static print is run every time.
x = tl.load(X + tl.arange(0, BLOCK))
tl.static_print("", x)
tl.store(Y + tl.arange(0, BLOCK), x)
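The placeholder trick in the comment above works because constexpr arguments are part of the kernel's specialization key, so a fresh value forces a fresh compilation. A standalone sketch (illustrative names, not part of the diff):

import uuid
import torch
import triton
import triton.language as tl

@triton.jit
def _noisy_kernel(X, BLOCK: tl.constexpr, PLACEHOLDER: tl.constexpr):
    # tl.static_print runs at compile time only, so a cache hit would stay silent.
    tl.static_print("compiling with BLOCK =", BLOCK)
    tl.store(X + tl.arange(0, BLOCK), 0)

x = torch.empty(128, dtype=torch.int32, device="cuda")
_noisy_kernel[(1, )](x, BLOCK=128, PLACEHOLDER=uuid.uuid4())  # compiles and prints
_noisy_kernel[(1, )](x, BLOCK=128, PLACEHOLDER=uuid.uuid4())  # new constexpr, so it compiles and prints again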
@@ -33,21 +62,36 @@ def kernel_no_arg_print():
print("", tl.program_id(0))
@triton.jit
def kernel_print_no_arg():
print("no arg")
def test_print(func: str, data_type: str):
shape = (128, )
# limit the range of integers so that the sum does not overflow
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda').to(getattr(torch, data_type))
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
if func == "device_print":
kernel_device_print[(1,)](x, y, BLOCK=shape[0])
kernel_device_print[(1, )](x, y, BLOCK=shape[0])
elif func == "print":
kernel_print[(1,)](x, y, BLOCK=shape[0])
kernel_print[(1, )](x, y, BLOCK=shape[0])
elif func == "device_print_large":
kernel_device_print_large[(1, 2)](BLOCK_M=64, BLOCK_N=128)
elif func == "print_multiple_args":
kernel_print_multiple_args[(1, )](x, y, BLOCK=shape[0])
elif func == "device_print_multiple_args":
kernel_device_print_multiple_args[(1, )](x, y, BLOCK=shape[0])
elif func == "static_print":
kernel_static_print[(1,)](x, y, BLOCK=shape[0])
kernel_static_print[(1, )](x, y, BLOCK=shape[0], PLACEHOLDER=uuid.uuid4())
elif func == "no_arg_print":
kernel_no_arg_print[(1,)](num_warps=4)
kernel_no_arg_print[(1, )](num_warps=4)
elif func == "print_no_arg":
kernel_print_no_arg[(1, )](num_warps=4)
else:
assert f"Unknown kernel: {func}"
if func != "no_arg_print":
if func != "print_no_arg" and func != "no_arg_print" and func != "device_print_large" and \
func != "print_multiple_args" and func != "device_print_multiple_args":
assert_close(y, x)

View File

@@ -1,4 +1,3 @@
from __future__ import annotations
import torch
@@ -14,8 +13,8 @@ def test_annotations(device):
pass
x = torch.empty(1, device=device)
_kernel[(1,)](x, x.shape[0], 32)
_kernel[(1, )](x, x.shape[0], 32)
try:
_kernel[(1,)](x.shape[0], x.shape[0], 32)
_kernel[(1, )](x.shape[0], x.shape[0], 32)
except AttributeError:
pass

View File

@@ -17,10 +17,12 @@ def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option:
tl.store(b_block_ptr, a, boundary_check=(0, ))
@pytest.mark.parametrize("dtype_str, n, padding_option",
[(dtype_str, n, padding) for dtype_str in ("bool", "int16", "float16")
for n in (64, 128, 256, 512, 1024)
for padding in ("zero", "nan")])
@pytest.mark.parametrize("dtype_str, n, padding_option", [ #
(dtype_str, n, padding)
for dtype_str in ("bool", "int16", "float16")
for n in (64, 128, 256, 512, 1024)
for padding in ("zero", "nan") #
])
def test_block_copy(dtype_str, n, padding_option):
capability = torch.cuda.get_device_capability()
if torch.version.hip is None and capability[0] >= 9:
@@ -35,31 +37,31 @@ def test_block_copy(dtype_str, n, padding_option):
a = torch.randn((n, ), device="cuda", dtype=dtype)
b = torch.zeros((n, ), device="cuda", dtype=dtype)
grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]), )
block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, padding_option=padding_option)
assert torch.all(a[0: n // 2] == b[0: n // 2])
assert torch.all(a[0:n // 2] == b[0:n // 2])
if padding_option == "zero":
assert torch.all(b[n // 2: n] == 0)
assert torch.all(b[n // 2:n] == 0)
else:
assert torch.all(torch.isnan(b[n // 2: n]))
assert torch.all(torch.isnan(b[n // 2:n]))
@triton.jit
def matmul_no_scf_with_advance_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr
def matmul_no_scf_with_advance_kernel( #
a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr #
):
offs_m = tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0),
block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))
# The two lines below just test negative offsets for the `advance` API and could be removed
a_block_ptr = tl.advance(a_block_ptr, (BLOCK_M, -BLOCK_K))
a_block_ptr = tl.advance(a_block_ptr, (-BLOCK_M, BLOCK_K))
@@ -71,14 +73,12 @@ def matmul_no_scf_with_advance_kernel(
tl.store(c_ptrs, c)
@pytest.mark.parametrize("shape, num_warps", [
(shape, num_warps)
for shape in [
@pytest.mark.parametrize("shape, num_warps", [ #
(shape, num_warps) for shape in [
[64, 64, 16],
[64, 64, 32],
[64, 64, 64],
]
for num_warps in [4, 8]
] for num_warps in [4, 8]
])
def test_block_ptr_matmul_no_scf(shape, num_warps):
capability = torch.cuda.get_device_capability()
@@ -91,12 +91,13 @@ def test_block_ptr_matmul_no_scf(shape, num_warps):
c = torch.empty((m, n), device="cuda", dtype=torch.float32)
grid = lambda META: (1, )
matmul_no_scf_with_advance_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
M=m, N=n, K=k,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1),
BLOCK_M=m, BLOCK_N=n, BLOCK_K=k,
num_warps=num_warps)
matmul_no_scf_with_advance_kernel[grid](
a_ptr=a, b_ptr=b, c_ptr=c, #
M=m, N=n, K=k, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_cm=c.stride(0), stride_cn=c.stride(1), #
BLOCK_M=m, BLOCK_N=n, BLOCK_K=k, #
num_warps=num_warps)
golden = torch.matmul(a, b)
torch.testing.assert_close(c, golden, check_dtype=False)

File diff suppressed because it is too large

View File

@@ -12,6 +12,7 @@ import triton.language as tl
class PhiloxConfig:
def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE)
@@ -40,6 +41,7 @@ PHILOX_64 = PhiloxConfig(
class CustomPhilox4x:
def __init__(self, seed, config):
self._config = config
seed = self._into_pieces(seed)
@@ -92,6 +94,7 @@ class CustomPhilox4x:
class CustomPhilox(CustomPhilox4x):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.buffer = []
@@ -111,10 +114,9 @@ BLOCK = 1024
# test generation of random uint32
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in ['10', '4,53', '10000']
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
)
@pytest.mark.parametrize('size, seed', [(size, seed)
for size in ['10', '4,53', '10000']
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]])
def test_randint(size, seed, device):
size = list(map(int, size.split(',')))
@@ -123,10 +125,11 @@ def test_randint(size, seed, device):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
rand = tl.randint(seed, offset)
tl.store(X + offset, rand, mask=offset < N)
# triton result
x = torch.empty(size, dtype=torch.int32, device=device)
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
grid = (triton.cdiv(N, BLOCK), )
kernel[grid](x, N, seed)
out_tri = x.cpu().numpy().astype(np.uint32).flatten().tolist()
# reference result
@@ -134,44 +137,44 @@ def test_randint(size, seed, device):
out_ref = [gen.random_raw()[0] for _ in out_tri]
assert out_tri == out_ref
# test uniform PRNG
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]
for seed in [0, 42, 124, 54]]
)
@pytest.mark.parametrize('size, seed', [(size, seed) for size in [1000000] for seed in [0, 42, 124, 54]])
def test_rand(size, seed, device):
@triton.jit
def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
rand = tl.rand(seed, offset)
tl.store(X + offset, rand, mask=offset < N)
# triton result
x = torch.empty(size, dtype=torch.float32, device=device)
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
grid = (triton.cdiv(N, BLOCK), )
kernel[grid](x, N, seed)
assert all((x >= 0) & (x <= 1))
assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
# test normal PRNG
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]
for seed in [0, 42, 124, 54]]
)
@pytest.mark.parametrize('size, seed', [(size, seed) for size in [1000000] for seed in [0, 42, 124, 54]])
def test_randn(size, seed, device):
@triton.jit
def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
rand = tl.randn(seed, offset)
tl.store(X + offset, rand, mask=offset < N)
# triton result
x = torch.empty(size, dtype=torch.float32, device=device)
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
grid = (triton.cdiv(N, BLOCK), )
kernel[grid](x, N, seed)
assert abs(x.mean()) < 1e-2
assert abs(x.std() - 1) < 1e-2
@@ -179,7 +182,9 @@ def test_randn(size, seed, device):
# tl.rand() should never produce >=1.0
def test_rand_limits(device):
@triton.jit
def kernel(input, output, n: tl.constexpr):
idx = tl.arange(0, n)
@@ -192,7 +197,7 @@ def test_rand_limits(device):
torch.iinfo(torch.int32).max,
], dtype=torch.int32, device=device)
output = torch.empty(2, dtype=torch.float32, device=device)
kernel[(1,)](min_max_int32, output, 2)
kernel[(1, )](min_max_int32, output, 2)
assert output[0] == output[1]
assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0

View File

@@ -1,6 +1,8 @@
import itertools
import os
import subprocess
import sys
from collections import Counter
import pytest
@@ -9,53 +11,88 @@ print_path = os.path.join(dir_path, "print_helper.py")
assert_path = os.path.join(dir_path, "assert_helper.py")
# TODO: bfloat16 after LLVM-15
assert_types = ["device_assert", "assert", "static_assert", "no_debug"]
assert_types = ["device_assert", "device_assert_passes", "assert", "static_assert", "no_debug", "double_assert"]
nested_types = [(caller, callee) for caller in ["true", "false", "none"] for callee in ["true", "false", "none"]]
torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"]
@pytest.mark.parametrize("func_type, data_type",
[("device_print", data_type) for data_type in torch_types] + [("print", "int32"), ("static_print", "int32"), ("no_arg_print", "int32")])
# TODO: Print with multiple operands
@pytest.mark.parametrize("func_type, data_type", [("device_print", data_type) for data_type in torch_types] + [
("print", "int32"),
("static_print", "int32"),
("no_arg_print", "int32"),
("print_no_arg", "int32"),
("device_print_large", "int32"),
("print_multiple_args", "int32"),
("device_print_multiple_args", "int32"),
])
def test_print(func_type: str, data_type: str):
proc = subprocess.Popen([sys.executable, print_path, func_type, data_type], stdout=subprocess.PIPE, shell=False)
outs, _ = proc.communicate()
outs = outs.split()
new_lines = set()
for line in outs:
try:
value = line
if func_type != "static_print":
value = int(float(line))
new_lines.add(value)
except Exception as e:
print(e)
if func_type != "static_print" and func_type != "no_arg_print":
outs = [line for line in outs.decode("UTF-8").split("\n") if line]
# Format is
# pid (<x>, <y>, <z>) idx (<i1>, <i2>, ...) <prefix> (operand <n>) <elem>
expected_lines = Counter()
if func_type == "print" or func_type == "device_print":
for i in range(128):
assert i in new_lines
else:
assert len(new_lines) == 1
line = f"pid (0, 0, 0) idx ({i:3}) x: {i}"
if data_type.startswith("float"):
line += ".000000"
expected_lines[line] = 1
elif func_type == "static_print":
expected_lines[" int32[constexpr[128]]"] = 1
elif func_type == "no_arg_print":
expected_lines["pid (0, 0, 0) idx (): 0"] = 128
elif func_type == "print_no_arg":
expected_lines["pid (0, 0, 0) no arg"] = 128
elif func_type == "device_print_large":
for i, j, k in itertools.product(range(2), range(64), range(128)):
expected_lines[f"pid (0, {i}, 0) idx ({j:2}, {k:3}) x: 1"] = 1
elif func_type == "print_multiple_args" or func_type == "device_print_multiple_args":
for i in range(128):
expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 0) {i}"] = 1
expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 1) 1"] = 1
actual_lines = Counter()
for line in outs:
actual_lines[line] += 1
diff = Counter(actual_lines)
diff.subtract(expected_lines)
for line, delta in diff.items():
if delta == 0:
continue
print(f'Expected line "{line}" {expected_lines[line]} time(s), but saw {actual_lines[line]} time(s)')
assert all(delta == 0 for delta in diff.values())
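For orientation, the documented format expands to lines like these for element index 42 (derived from the expected_lines construction above; the index is arbitrary):
#   pid (0, 0, 0) idx ( 42) x: 42
#   pid (0, 0, 0) idx ( 42): (operand 0) 42
#   pid (0, 0, 0) idx ( 42): (operand 1) 1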
@pytest.mark.parametrize("func_type", assert_types)
def test_assert(func_type: str):
os.environ["TRITON_DEBUG"] = "1"
proc = subprocess.Popen([sys.executable, assert_path, func_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
proc = subprocess.Popen([sys.executable, assert_path, func_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE,
shell=False)
_, errs = proc.communicate()
errs = errs.splitlines()
num_errs = 0
for err in errs:
if "x != 0" in err.decode("utf-8"):
num_errs += 1
# Check for segfaults.
assert all("segmentation fault" not in line.decode("utf-8").lower() for line in errs)
os.environ["TRITON_DEBUG"] = "0"
if func_type != "static_assert":
assert num_errs == 127
else:
if func_type == "static_assert" or func_type == "device_assert_passes":
assert num_errs == 0
else:
assert num_errs == 127
@pytest.mark.parametrize("caller_type, callee_type", nested_types)
def test_assert_nested(caller_type, callee_type):
proc = subprocess.Popen([sys.executable, assert_path, caller_type, callee_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
proc = subprocess.Popen([sys.executable, assert_path, caller_type, callee_type], stdout=subprocess.PIPE,
stderr=subprocess.PIPE, shell=False)
_, errs = proc.communicate()
errs = errs.splitlines()
num_errs = 0

View File

@@ -68,8 +68,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
b_ref = do_mask(b_ref) if is_dds else b_ref
a_ref.retain_grad()
b_ref.retain_grad()
c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
b_ref.transpose(2, 3) if TRANS_B else b_ref)
c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref, b_ref.transpose(2, 3) if TRANS_B else b_ref)
c_ref.backward(dc_ref)
c_ref = do_sparsify(c_ref) if is_sdd else c_ref
da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad
@@ -172,7 +171,7 @@ def test_attention_fwd_bwd(
value.retain_grad()
attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale)
# ad hoc loss
loss = (attn_out ** 2).mean()
loss = (attn_out**2).mean()
loss.backward()
grads = [query.grad, key.grad, value.grad]
@@ -189,7 +188,7 @@ def test_attention_fwd_bwd(
probs = torch.softmax(scores, dim=-1)
torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
# ad hoc loss
torch_loss = (torch_attn_out ** 2).mean()
torch_loss = (torch_attn_out**2).mean()
torch_loss.backward()
torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
@@ -209,8 +208,10 @@ def triton_attention(
value: torch.Tensor,
scale: float,
):
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device)
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device)
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True,
device=value.device)
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False,
device=value.device)
sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device)
w = sparse_dot_sdd_nt(query, key)

View File

@@ -5,14 +5,13 @@ import triton
import triton.ops
@pytest.mark.parametrize("M, N, dtype, mode",
[
(M, N, dtype, mode) for M in [1024, 821]
for N in [512, 857, 1871, 2089, 8573, 31000]
for dtype in ['float16', 'float32']
for mode in ['forward', 'backward']
]
)
@pytest.mark.parametrize("M, N, dtype, mode", [ #
(M, N, dtype, mode)
for M in [1024, 821]
for N in [512, 857, 1871, 2089, 8573, 31000]
for dtype in ['float16', 'float32']
for mode in ['forward', 'backward']
])
def test_op(M, N, dtype, mode):
capability = torch.cuda.get_device_capability()
if capability[0] < 8 and dtype == "bfloat16":

View File

@@ -5,10 +5,12 @@ import triton
import triton.ops
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 4, 512, 16),
(2, 4, 512, 32),
(2, 4, 512, 64),
(2, 4, 512, 128)])
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ #
(2, 4, 512, 16),
(2, 4, 512, 32),
(2, 4, 512, 64),
(2, 4, 512, 128),
])
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('causal', [True, False])
@pytest.mark.parametrize('seq_par', [True, False])
@@ -56,6 +58,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
# # triton implementation
tri_out = triton.ops.attention(q, k, v, causal, sm_scale, seq_par)
<<<<<<< HEAD
# print(ref_out)
# print(tri_out)
if torch.version.hip is None:
@@ -70,3 +73,74 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)
=======
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# compare
atol = 1e-1 if dtype == torch.bfloat16 else 1e-2
torch.testing.assert_close(ref_out, tri_out, atol=atol, rtol=0)
torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)
try:
from flash_attn.flash_attn_interface import flash_attn_func
HAS_FLASH = True
except BaseException:
HAS_FLASH = False
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [
triton.testing.Benchmark(
x_names=['N_CTX'], x_vals=[2**i for i in range(10, 14)], line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), styles=[('red', '-'), ('blue', '-')], ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{causal}-{seq_par}', args={
'H': N_HEADS,
'BATCH': BATCH,
'D_HEAD': D_HEAD,
'dtype': torch.float16,
'mode': mode,
'causal': causal,
'seq_par': seq_par,
}) for mode in ['fwd', 'bwd'] for causal in [True, False] for seq_par in [True, False]
]
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, causal, seq_par, provider, dtype=torch.float16, device="cuda"):
assert mode in ['fwd', 'bwd']
warmup = 25
rep = 100
sm_scale = 1.3
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
if provider == "triton":
fn = lambda: triton.ops.attention(q, k, v, causal, sm_scale, seq_par)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
if provider == "flash":
lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32)
cu_seqlens[1:] = lengths.cumsum(0)
fn = lambda: flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=sm_scale, causal=causal)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
# only works on post-Ampere GPUs right now
# bench_flash_attention.run(save_path='.', print_data=True)
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52

View File

@@ -8,7 +8,8 @@ import triton.language as tl
def test_normalization_with_remat():
@triton.jit
def triton_(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
def triton_(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel, rnumel, XBLOCK: tl.constexpr,
RBLOCK: tl.constexpr):
xnumel = 512
rnumel = 4096
xoffset = tl.program_id(0) * XBLOCK
@@ -52,7 +53,7 @@ def test_normalization_with_remat():
arg115_1 = torch.rand(64, device="cuda")
arg8_1 = torch.rand(64, device="cuda")
arg9_1 = torch.rand(64, device="cuda")
triton_[(512,)](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048)
triton_[(512, )](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048)
torch.testing.assert_close(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0)
@@ -148,7 +149,7 @@ def test_avg_pool_bw():
inp = torch.ones(8, 2048, 8, 8, device="cuda", dtype=torch.half)
out = torch.ones_like(inp) * 3
numel = inp.numel()
triton_[(numel // 1024,)](inp, out, 1024)
triton_[(numel // 1024, )](inp, out, 1024)
out_ref = torch.ones_like(inp)
out_ref[:, :, 1:7, 0::7] = 2 / 3
out_ref[:, :, 0::7, 1:7] = 2 / 3
@@ -159,6 +160,7 @@ def test_avg_pool_bw():
@pytest.mark.parametrize("RBLOCK", [1, 16, 32, 64, 128])
@pytest.mark.parametrize("num_warps", [1, 4])
def test_scan2d_broadcast(RBLOCK, num_warps):
@triton.jit(debug=True)
def fn(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
rindex = tl.arange(0, RBLOCK)[None, :]
@@ -172,12 +174,13 @@ def test_scan2d_broadcast(RBLOCK, num_warps):
XBLOCK = 4
input = torch.randint(0, 10, (1, RBLOCK), dtype=torch.int64, device='cuda')
output = torch.empty((XBLOCK, RBLOCK), dtype=torch.int64, device='cuda')
fn[(1,)](input, output, XBLOCK, RBLOCK, num_warps=num_warps)
fn[(1, )](input, output, XBLOCK, RBLOCK, num_warps=num_warps)
ref = input.cumsum(1).broadcast_to((XBLOCK, RBLOCK))
torch.testing.assert_close(output, ref)
def test_scan2d_for():
@triton.jit
def fn(out_ptr0, rnumel, RBLOCK: tl.constexpr):
rbase = tl.arange(0, RBLOCK)[None, :]
@@ -190,6 +193,6 @@ def test_scan2d_for():
RBLOCK = 8
out0 = torch.empty(RBLOCK, device="cuda", dtype=torch.int64)
fn[(1,)](out0, RBLOCK, RBLOCK)
fn[(1, )](out0, RBLOCK, RBLOCK)
ref = torch.arange(RBLOCK, device="cuda", dtype=torch.int64) + 1
torch.testing.assert_close(out0, ref)

View File

@@ -19,7 +19,7 @@ def f8_to_f16(x, dtype):
tl.store(Y + offs, x, mask=mask)
ret = torch.empty(x.shape, dtype=torch.float16, device=x.device)
grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),)
grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']), )
dtype = getattr(tl, dtype)
kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024)
return ret
@@ -28,87 +28,88 @@ def f8_to_f16(x, dtype):
@pytest.mark.parametrize(
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM",
itertools.chain(
*[
[
# 1 warp
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
# 2 warp
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
# 4 warp
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
# 8 warp
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
# variable input
(128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True, True),
(128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True, True),
(128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True, True),
(128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True, True),
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
],
*[[
# 1 warp
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
# 2 warp
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
# 4 warp
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
# 8 warp
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
# variable input
(128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True, True),
(128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True, True),
(128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True, True),
(128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True, True),
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]],
# n-stage
*[
[
(16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True, True),
(64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True, True),
(128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True, True),
(256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True, True),
(128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True, True),
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [4]
],
*[[
(16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True, True),
(64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True, True),
(128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True, True),
(256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True, True),
(128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True, True),
]
for DTYPE in ["float16", "bfloat16", "float32"]
for AT in [False, True]
for BT in [False, True]
for STAGES in [4]],
# mixed-precision
*[
[
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
] for ADTYPE, BDTYPE in [("float8e4nv", "float8e5"),
("float8e4nv", "float8e4nv"),
("float8e5", "float8e4nv"),
("float8e5", "float8e5"),
("float8e4b15", "float8e4b15"),
("float8e4nv", "float16"),
("float16", "float8e5"),
("float16", "float32"),
("float32", "float16"),
("bfloat16", "float32"),
("float32", "bfloat16")] for AT in [False, True] for BT in [False, True] for FASTACCUM in [True, False]
],
*[[
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
] for ADTYPE, BDTYPE in [
("float8e4nv", "float8e5"),
("float8e4nv", "float8e4nv"),
("float8e5", "float8e4nv"),
("float8e5", "float8e5"),
("float8e4b15", "float8e4b15"),
("float8e4nv", "float16"),
("float16", "float8e5"),
("float16", "float32"),
("float32", "float16"),
("bfloat16", "float32"),
("float32", "bfloat16"),
] for AT in [False, True] for BT in [False, True] for FASTACCUM in [True, False]],
# mixed-precision block layout
*[
[
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True),
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True),
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False, True),
] for ADTYPE, BDTYPE in [("float8e4nv", "float16"),
("float16", "float8e5"),
("float16", "float32"),
("float32", "float16"),
("bfloat16", "float32"),
("float32", "bfloat16")] for AT in [False, True] for BT in [False, True]
],
*[[
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True),
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True),
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False, True),
] for ADTYPE, BDTYPE in [
("float8e4nv", "float16"),
("float16", "float8e5"),
("float16", "float32"),
("float32", "float16"),
("bfloat16", "float32"),
("float32", "bfloat16"),
] for AT in [False, True] for BT in [False, True]],
),
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM):
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32,
F8_FASTACCUM):
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
@@ -147,7 +148,7 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
return torch.randint(-128, 127, (m, n), device="cuda", dtype=torch.int8)
dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[dtype]
exponents = torch.randint(-10, 0, size=(m, n))
ret = (2. ** exponents).to(dtype).to("cuda")
ret = (2.**exponents).to(dtype).to("cuda")
return ret
# allocate/transpose inputs

View File

@@ -17,6 +17,25 @@ def test_kwargs():
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(src + offsets, mask=offsets < N)
tl.store(dst + offsets, x, mask=offsets < N)
grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']),)
grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )
_kernel[grid](dst, src, N)
_kernel[grid](dst=dst, src=src, N=N)
def test_restore():
N = 1024
src = torch.zeros(N, device='cuda')
configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]
@triton.autotune(configs=configs, key=['N'], restore_value=['src'])
@triton.jit
def _kernel(src, N, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(src + offsets, mask=offsets < N) + 1
tl.store(src + offsets, x, mask=offsets < N)
grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )
_kernel[grid](src, N)
triton.testing.assert_close(src, torch.ones_like(src))

View File

@@ -80,11 +80,12 @@ def test_reuse():
def inc_counter(*args, **kwargs):
nonlocal counter
counter += 1
JITFunction.cache_hook = inc_counter
reset_tmp_dir()
x = torch.empty(1, dtype=torch.int32, device='cuda')
for i in range(10):
kernel[(1,)](x, 1, BLOCK=1024)
kernel[(1, )](x, 1, BLOCK=1024)
assert counter == 1
@@ -95,17 +96,19 @@ def test_specialize(mode):
def inc_counter(*args, **kwargs):
nonlocal counter
counter += 1
JITFunction.cache_hook = inc_counter
reset_tmp_dir()
x = torch.empty(1, dtype=torch.int32, device='cuda')
function = {'enable': kernel, 'disable': kernel_nospec}[mode]
target = {'enable': 4, 'disable': 1}[mode]
for i in [1, 2, 4, 8, 16, 32]:
function[(1,)](x, i, BLOCK=512)
function[(1, )](x, i, BLOCK=512)
assert counter == target
def test_annotation():
@triton.jit
def kernel(X, i: tl.int32):
tl.store(X, i)
@@ -113,14 +116,15 @@ def test_annotation():
x = torch.empty(1, dtype=torch.int32, device='cuda')
device = torch.cuda.current_device()
kernel[(1,)](x, 1)
kernel[(1,)](x, 8)
kernel[(1,)](x, 16)
kernel[(1,)](x, 17)
kernel[(1, )](x, 1)
kernel[(1, )](x, 8)
kernel[(1, )](x, 16)
kernel[(1, )](x, 17)
assert len(kernel.cache[device]) == 4
def test_constexpr_not_callable() -> None:
@triton.jit
def kernel(X, c: tl.constexpr):
tl.store(X, 2)
@@ -141,11 +145,11 @@ def test_constexpr_not_callable() -> None:
def test_jit_warmup_cache() -> None:
@triton.jit
def kernel_add(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.store(o + idx,
tl.load(a + idx) + tl.load(b + idx))
tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))
args = [
torch.randn(32, dtype=torch.float32, device="cuda"),
@@ -155,31 +159,31 @@ def test_jit_warmup_cache() -> None:
]
device = torch.cuda.current_device()
assert len(kernel_add.cache[device]) == 0
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add.cache[device]) == 1
kernel_add.warmup(*args, grid=(1,))
kernel_add.warmup(*args, grid=(1, ))
assert len(kernel_add.cache[device]) == 1
kernel_add.warmup(*args, grid=(1,))
kernel_add.warmup(*args, grid=(1, ))
assert len(kernel_add.cache[device]) == 1
def test_jit_debug() -> None:
@triton.jit
def kernel_add(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.device_assert(idx < 32, "idx < 32")
tl.store(o + idx,
tl.load(a + idx) + tl.load(b + idx))
tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))
device = torch.cuda.current_device()
assert len(kernel_add.cache[device]) == 0
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add.cache[device]) == 1
kernel_add.debug = False
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add.cache[device]) == 2
kernel_add.debug = True
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add.cache[device]) == 3
bins = list(kernel_add.cache[device].values())
assert bins[2].asm['ttir'] != bins[1].asm['ttir']
@@ -192,13 +196,14 @@ def add_fn(a, b, o, N: tl.constexpr):
def test_jit_noinline() -> None:
@triton.jit
def kernel_add_device(a, b, o, N: tl.constexpr):
add_fn(a, b, o, N)
device = torch.cuda.current_device()
assert len(kernel_add_device.cache[device]) == 0
kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add_device.cache[device]) == 1
bins = list(kernel_add_device.cache[device].values())
inline_ttir = bins[0].asm['ttir']
@@ -206,7 +211,7 @@ def test_jit_noinline() -> None:
add_fn.hash = None
kernel_add_device.hash = None
kernel_add_device.cache[device].clear()
kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add_device.cache[device]) == 1
bins = list(kernel_add_device.cache[device].values())
noinline_ttir = bins[0].asm['ttir']
@@ -214,6 +219,7 @@ def test_jit_noinline() -> None:
def test_memory_leak() -> None:
@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
xnumel = 10

View File

@@ -31,11 +31,11 @@ def test_memory_leak() -> None:
try:
inp = torch.randn(10, device='cuda')
out = torch.randn(10, device='cuda')
kernel[(10,)](inp, out, 10, XBLOCK=16)
kernel[(10, )](inp, out, 10, XBLOCK=16)
gc.collect()
begin, _ = tracemalloc.get_traced_memory()
for _ in range(100):
kernel[(10,)](inp, out, 10, XBLOCK=16)
kernel[(10, )](inp, out, 10, XBLOCK=16)
gc.collect()
end, _ = tracemalloc.get_traced_memory()
assert end - begin < 30000

View File

@@ -17,9 +17,11 @@ def reset_tmp_dir():
shutil.rmtree(tmpdir, ignore_errors=True)
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])
instance_descriptor = namedtuple("instance_descriptor",
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])
<<<<<<< HEAD
def get_device_type():
try:
import torch
@@ -36,10 +38,15 @@ def get_device_type():
def compile_fn(config, device_type, cc):
=======
def compile_fn(config, cc):
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
@triton.jit
def kernel_sub(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777)
triton.compile(
fn=kernel_sub,
signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
@@ -57,15 +64,24 @@ def test_compile_in_subproc() -> None:
config = instance_descriptor(tuple(range(4)), (), (), ())
multiprocessing.set_start_method('fork')
<<<<<<< HEAD
proc = multiprocessing.Process(
target=compile_fn,
args=(config, device_type, cc))
=======
proc = multiprocessing.Process(target=compile_fn, args=(config, cc))
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
proc.start()
proc.join()
assert proc.exitcode == 0
<<<<<<< HEAD
def compile_fn_dot(config, device_type, cc):
=======
def compile_fn_dot(config, cc):
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
@triton.jit
def kernel_dot(Z):
offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :]
@@ -90,9 +106,13 @@ def test_compile_in_forked_subproc() -> None:
config = instance_descriptor(tuple(range(1)), (), (), ())
assert multiprocessing.get_start_method() == 'fork'
<<<<<<< HEAD
proc = multiprocessing.Process(
target=compile_fn_dot,
args=(config, device_type, cc))
=======
proc = multiprocessing.Process(target=compile_fn_dot, args=(config, cc))
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
proc.start()
proc.join()
assert proc.exitcode == 0

View File

@@ -59,7 +59,7 @@ def kernel(C, A, B, M, N, K,
tl.store(c_ptrs, c)
"""
test_utils_src = '''
test_utils_src = """
#include <cuda.h>
#include <stdio.h>
#include <stdint.h>
@@ -93,23 +93,26 @@ static void read_csv_to_buffer(char *filename, int16_t *buffer, int size) {
index++;
}
fclose(file);
}'''
}"""
def gen_kernel_library(dir, libname):
c_files = glob.glob(os.path.join(dir, "*.c"))
subprocess.run(["gcc"] + c_files + ["-I", cuda_include_dir(),
"-c", "-fPIC"],
check=True, cwd=dir)
subprocess.run(
["gcc"] + c_files + ["-I", cuda_include_dir(), "-c", "-fPIC"],
check=True,
cwd=dir,
)
o_files = glob.glob(os.path.join(dir, "*.o"))
subprocess.run(["gcc"] + o_files + ["-shared",
"-o", libname,
"-L", libcuda_dirs()[0]],
check=True, cwd=dir)
subprocess.run(
["gcc"] + o_files + ["-shared", "-o", libname, "-L", libcuda_dirs()[0]],
check=True,
cwd=dir,
)
def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
test_src = f'''
test_src = f"""
int main(int argc, char **argv) {{
int M = {M}, N = {N}, K = {K};
@@ -165,17 +168,29 @@ int main(int argc, char **argv) {{
cuMemFree(C);
cuCtxDestroy(ctx);
}}
'''
"""
src = test_utils_src + test_src
with open(os.path.join(dir, "test.c"), "w") as file:
file.write(src)
subprocess.run(["gcc"] + ["test.c",
"-I", cuda_include_dir(),
"-L", libcuda_dirs()[0],
"-l", "cuda",
"-L", dir,
"-l", "kernel",
"-o", exe], check=True, cwd=dir)
subprocess.run(
["gcc"] + [
"test.c",
"-I",
cuda_include_dir(),
"-L",
libcuda_dirs()[0],
"-l",
"cuda",
"-L",
dir,
"-l",
"kernel",
"-o",
exe,
],
check=True,
cwd=dir,
)
def write_triton_kernels(dir, src, util_src):
@@ -190,16 +205,67 @@ def write_triton_kernels(dir, src, util_src):
return kernel_path
def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints):
def _compile_kernel(dir, signature, kernel_name, out_name, out_path, num_warps, grid, kernel_path):
compiler_path = os.path.join(triton.tools.__path__[0], "compile.py")
subprocess.run(
[
sys.executable,
compiler_path,
"-n",
kernel_name,
"--signature",
signature,
"--out-name",
out_name,
"-o",
out_path,
"-w",
str(num_warps),
"-g",
grid,
kernel_path,
],
check=True,
cwd=dir,
)
# Edge case kernel with no specialization
def compile_aot_kernel_no_specialization(dir, kernel_path, dtype, BM, BN, BK):
# compile all desired configs
sig = f"*fp32, *{dtype}, *{dtype}, i32, i32, i32, i32, i32, i32, i32, i32, i32, {BM}, {BN}, {BK}"
name = f"matmul_{dtype}"
grid = f"M/{BM}, N/{BN}, 1"
_compile_kernel(
dir=dir,
signature=sig,
kernel_name="kernel",
out_name=name,
out_path=name,
num_warps=1,
grid=grid,
kernel_path=kernel_path,
)
def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints):
# compile all desired configs
for ha in ha_hb_hints:
for hb in ha_hb_hints:
sig = f'*fp32:16, *{dtype}:16, *{dtype}:16, i32, i32, i32, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}'
sig = f"*fp32:16, *{dtype}:16, *{dtype}:16, i32, i32, i32, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}"
name = f"matmul_{dtype}"
grid = f'M/{BM}, N/{BN}, 1'
subprocess.run([sys.executable, compiler_path, "-n", "kernel", "--signature", sig, "--out-name", name, "-o", name, "-w", "1", "-g", grid, kernel_path], check=True, cwd=dir)
grid = f"M/{BM}, N/{BN}, 1"
_compile_kernel(
dir=dir,
signature=sig,
kernel_name="kernel",
out_name=name,
out_path=name,
num_warps=1,
grid=grid,
kernel_path=kernel_path,
)
def link_aot_kernels(dir):
@@ -221,11 +287,42 @@ def generate_matmul_test_data(dir, M, N, K):
return a, b, a_path, b_path, c_path
# Test edge case where the provided kernel signature has no specializations
def test_compile_link_matmul_no_specialization():
np.random.seed(3)
with tempfile.TemporaryDirectory() as tmp_dir:
dtype = "fp16"
BM, BN, BK = 16, 16, 16
kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
compile_aot_kernel_no_specialization(tmp_dir, kernel_path, dtype, BM, BN, BK)
link_aot_kernels(tmp_dir)
# compile test case
M, N, K = 16, 16, 16
gen_kernel_library(tmp_dir, "libkernel.so")
gen_test_bin(tmp_dir, M, N, K)
# initialize test data
a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K)
# run test case
env = os.environ.copy()
env["LD_LIBRARY_PATH"] = tmp_dir
subprocess.run(["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir)
# read data and compare against reference
c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
c_tri = c.reshape((M, N)).view(np.float32)
c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32))
np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.0)
def test_compile_link_matmul():
np.random.seed(3)
with tempfile.TemporaryDirectory() as tmp_dir:
dtype = "fp16"
BM, BN, BK = 16, 16, 16
@@ -250,7 +347,7 @@ def test_compile_link_matmul():
c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
c_tri = c.reshape((M, N)).view(np.float32)
c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32))
np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.)
np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.0)
def test_launcher_has_no_available_kernel():
@@ -275,7 +372,13 @@ def test_launcher_has_no_available_kernel():
# run test case
env = os.environ.copy()
env["LD_LIBRARY_PATH"] = tmp_dir
result = subprocess.run(["./test", a_path, b_path, c_path], env=env, cwd=tmp_dir, capture_output=True, text=True)
result = subprocess.run(
["./test", a_path, b_path, c_path],
env=env,
cwd=tmp_dir,
capture_output=True,
text=True,
)
# It should fail since the launcher requires all the strides be 1 while they are not.
assert result.returncode == -6
@@ -286,7 +389,6 @@ def test_compile_link_autotune_matmul():
np.random.seed(3)
with tempfile.TemporaryDirectory() as tmp_dir:
dtype = "fp16"
kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
@@ -319,7 +421,12 @@ def test_compile_link_autotune_matmul():
env = os.environ.copy()
env["LD_LIBRARY_PATH"] = tmp_dir
subprocess.run([f"./{test_name}", a_path, b_path, c_path], check=True, cwd=tmp_dir, env=env)
subprocess.run(
[f"./{test_name}", a_path, b_path, c_path],
check=True,
cwd=tmp_dir,
env=env,
)
# read data and compare against reference
c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)

View File

@@ -45,12 +45,12 @@ __all__ = [
"tools",
]
# -------------------------------------
# misc. utilities that don't fit well
# into any specific module
# -------------------------------------
def cdiv(x: int, y: int):
return (x + y - 1) // y

View File

@@ -1,5 +1,5 @@
import functools
import hashlib
import importlib
import importlib.util
import os
@@ -10,8 +10,12 @@ from typing import Dict
from ..runtime.driver import DriverBase
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TRITON_VERSION = "2.1.0"
class BaseBackend:
def __init__(self, device_type: str) -> None:
self.device_type = device_type
@@ -104,7 +108,7 @@ def get_backend(device_type: str):
def _path_to_binary(binary: str):
base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
paths = [
os.environ.get("TRITON_PTXAS_PATH", ""),
os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
os.path.join(base_dir, "third_party", "cuda", "bin", binary)
]
@@ -132,3 +136,48 @@ def path_to_cuobjdump():
@functools.lru_cache()
def path_to_nvdisasm():
return _path_to_binary("nvdisasm")
@functools.lru_cache()
def compute_core_version_key():
import pkgutil
contents = []
# frontend
with open(__file__, "rb") as f:
contents += [hashlib.sha1(f.read()).hexdigest()]
# compiler
compiler_path = os.path.join(TRITON_PATH, 'compiler')
for lib in pkgutil.iter_modules([compiler_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.sha1(f.read()).hexdigest()]
# backend
libtriton_hash = hashlib.sha1()
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
while True:
chunk = f.read(1024**2)
if not chunk:
break
libtriton_hash.update(chunk)
contents.append(libtriton_hash.hexdigest())
# language
language_path = os.path.join(TRITON_PATH, 'language')
for lib in pkgutil.iter_modules([language_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.sha1(f.read()).hexdigest()]
return TRITON_VERSION + '-' + '-'.join(contents)
_cached_cuda_version_key = None
def get_cuda_version_key():
global _cached_cuda_version_key
if _cached_cuda_version_key is None:
key = compute_core_version_key()
try:
ptxas = path_to_ptxas()[0]
ptxas_version = subprocess.check_output([ptxas, "--version"])
except RuntimeError:
ptxas_version = b"NO_PTXAS"
_cached_cuda_version_key = key + '-' + hashlib.sha1(ptxas_version).hexdigest()
return _cached_cuda_version_key
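A hedged usage sketch only (the cache location below is an assumption, not taken from this diff): a key like this is the kind of value one would fold into an on-disk kernel-cache path so that compiled binaries are never reused across Triton, libtriton, or ptxas upgrades.

import os

def _version_scoped_cache_dir(root: str = "~/.triton/cache") -> str:
    # One sub-directory per version key; a new Triton or ptxas build lands in a fresh directory.
    return os.path.join(os.path.expanduser(root), get_cuda_version_key())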

View File

@@ -92,9 +92,15 @@ def _build(name, src, srcdir):
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
if is_hip():
ret = subprocess.check_call([cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{hip_lib_dir}", "-lamdhip64", "-o", so])
ret = subprocess.check_call([
cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC",
f"-L{hip_lib_dir}", "-lamdhip64", "-o", so
])
else:
cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
cc_cmd = [
cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda",
"-o", so
]
cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
ret = subprocess.check_call(cc_cmd)

View File

@@ -1,5 +1,8 @@
from .compiler import (CompiledKernel, compile, get_arch_default_num_stages,
get_arch_default_num_warps, instance_descriptor)
from .compiler import (CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps,
instance_descriptor)
from .errors import CompilationError
__all__ = ["compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", "get_arch_default_num_stages"]
__all__ = [
"compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps",
"get_arch_default_num_stages"
]

View File

@@ -10,8 +10,7 @@ from .._C.libtriton.triton import ir
from ..language import constexpr, tensor
# ideally we wouldn't need any runtime component
from ..runtime import JITFunction
from .errors import (CompilationError, CompileTimeAssertionFailure,
UnsupportedLanguageConstruct)
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
def mangle_ty(ty):
@@ -68,7 +67,10 @@ def _check_fn_args(node, fn, args):
if fn.noinline:
for idx, arg in enumerate(args):
if not _is_constexpr(arg) and not _is_triton_scalar(arg):
raise UnsupportedLanguageConstruct(fn.src, node, f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}')
raise UnsupportedLanguageConstruct(
fn.src, node,
f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}'
)
def _get_fn_file_line(fn):
@@ -89,6 +91,7 @@ _condition_types = {bool, int, type(None)} # Python types accepted for conditio
class enter_sub_region:
def __init__(self, generator):
self.generator = generator
@@ -109,6 +112,7 @@ class enter_sub_region:
# Check if the given syntax node has an "early" return
class ContainsReturnChecker(ast.NodeVisitor):
def __init__(self, gscope):
self.gscope = gscope
@@ -199,9 +203,10 @@ class ContainsReturnChecker(ast.NodeVisitor):
class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, attributes, constants, function_name, target,
module=None, is_kernel=False, function_types: Optional[Dict] = None,
debug=False, noinline=False, file_name: Optional[str] = None, begin_line=0):
def __init__(self, context, prototype, gscope, attributes, constants, function_name, target, module=None,
is_kernel=False, function_types: Optional[Dict] = None, debug=False, noinline=False,
file_name: Optional[str] = None, begin_line=0):
self.context = context
self.builder = ir.builder(context)
self.file_name = file_name
@@ -237,8 +242,10 @@ class CodeGenerator(ast.NodeVisitor):
))
def _define_name_lookup(self):
def local_lookup(name: str, absent):
value = self.lscope.get(name, absent) # this needs to be re-fetched from `self` every time, because it gets switched occasionally
# this needs to be re-fetched from `self` every time, because it gets switched occasionally
value = self.lscope.get(name, absent)
if value is not absent and name not in self.local_defs:
self.global_uses[name] = value
return value
@@ -255,8 +262,7 @@ class CodeGenerator(ast.NodeVisitor):
return name_lookup
def set_value(self, name: str,
value: Union[tensor, constexpr]) -> None:
def set_value(self, name: str, value: Union[tensor, constexpr]) -> None:
''' This function:
called by visit_Assign() & visit_FunctionDef() to store left value (lvalue)
1. record local defined name (FIXME: should consider control flow)
@@ -338,7 +344,8 @@ class CodeGenerator(ast.NodeVisitor):
self.visit(init_node)
# initialize function
visibility = "public" if self.is_kernel else "private"
self.fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline)
self.fn = self.builder.get_or_insert_function(self.module, self.function_name,
self.prototype.to_ir(self.builder), visibility, self.noinline)
self.module.push_back(self.fn)
entry = self.fn.add_entry_block()
arg_values = []
@@ -469,12 +476,23 @@ class CodeGenerator(ast.NodeVisitor):
rhs = self.visit(node.right)
method_name = self._method_name_for_bin_op.get(type(node.op))
if method_name is None:
raise UnsupportedLanguageConstruct(None, node, "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__))
raise UnsupportedLanguageConstruct(
None, node, "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__))
return self._apply_binary_method(method_name, lhs, rhs)
_method_name_for_bin_op: Dict[Type[ast.operator], str] = {
ast.Add: '__add__', ast.Sub: '__sub__', ast.Mult: '__mul__', ast.Div: '__truediv__',
ast.FloorDiv: '__floordiv__', ast.Mod: '__mod__', ast.Pow: '__pow__',
ast.LShift: '__lshift__', ast.RShift: '__rshift__', ast.BitAnd: '__and__', ast.BitOr: '__or__', ast.BitXor: '__xor__',
ast.Add: '__add__',
ast.Sub: '__sub__',
ast.Mult: '__mul__',
ast.Div: '__truediv__',
ast.FloorDiv: '__floordiv__',
ast.Mod: '__mod__',
ast.Pow: '__pow__',
ast.LShift: '__lshift__',
ast.RShift: '__rshift__',
ast.BitAnd: '__and__',
ast.BitOr: '__or__',
ast.BitXor: '__xor__',
}
def visit_then_else_blocks(self, node, liveins, then_block, else_block):
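The table above maps AST operator nodes to tensor dunder names, and `_apply_binary_method` then invokes the named method (the `__req__`/`__rne__` additions later in this diff suggest a reflected fallback when the left operand is a plain constant). A standalone toy illustration of the same dispatch pattern, in plain Python rather than Triton's actual class:

import ast

_METHOD_FOR_BIN_OP = {
    ast.Add: '__add__',
    ast.Sub: '__sub__',
    ast.Mult: '__mul__',
    ast.Div: '__truediv__',
}

class TinyEvaluator(ast.NodeVisitor):

    def visit_Expression(self, node):
        return self.visit(node.body)

    def visit_Constant(self, node):
        return node.value

    def visit_BinOp(self, node):
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        method = _METHOD_FOR_BIN_OP.get(type(node.op))
        if method is None:
            raise NotImplementedError(type(node.op).__name__)
        # dispatch like CodeGenerator._apply_binary_method: look up the
        # dunder name and call it on the left-hand value
        return getattr(lhs, method)(rhs)

assert TinyEvaluator().visit(ast.parse("3 * 7 + 1", mode="eval")) == 22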
@@ -508,7 +526,8 @@ class CodeGenerator(ast.NodeVisitor):
if name in then_defs or name in else_defs:
names.append(name)
ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type)
ir_ret_types.append(then_defs[name].handle.get_type() if name in then_defs else else_defs[name].handle.get_type())
ir_ret_types.append(then_defs[name].handle.get_type() if name in
then_defs else else_defs[name].handle.get_type())
# variable defined in then but not in else
if name in then_defs and name not in else_defs:
else_defs[name] = liveins[name]
@@ -602,8 +621,7 @@ class CodeGenerator(ast.NodeVisitor):
contains_return = ContainsReturnChecker(self.gscope).visit(node)
if self.scf_stack and contains_return:
raise UnsupportedLanguageConstruct(
None, node,
"Cannot have `return` statements inside `while` or `for` statements in triton "
None, node, "Cannot have `return` statements inside `while` or `for` statements in triton "
"(note that this also applies to `return` statements that are inside functions "
"transitively called from within `while`/`for` statements)")
elif self.scf_stack or not contains_return:
@@ -612,10 +630,13 @@ class CodeGenerator(ast.NodeVisitor):
self.visit_if_top_level(cond, node)
else:
cond = _unwrap_if_constexpr(cond)
if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks
# not isinstance - we insist the real thing, no subclasses and no ducks
if type(cond) not in _condition_types:
raise UnsupportedLanguageConstruct(
None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
', '.join(_.__name__ for _ in _condition_types), type(cond).__name__))
None, node,
"`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
', '.join(_.__name__ for _ in _condition_types),
type(cond).__name__))
if cond:
self.visit_compound_statement(node.body)
else:
@@ -624,15 +645,52 @@ class CodeGenerator(ast.NodeVisitor):
def visit_IfExp(self, node):
cond = self.visit(node.test)
if _is_triton_tensor(cond):
raise UnsupportedLanguageConstruct(
None, node,
"Triton does not support `if` expressions (ternary operators) with dynamic conditions, use `if` statements instead")
cond = cond.to(language.int1, _builder=self.builder)
# TODO: Deal w/ more complicated return types (e.g tuple)
with enter_sub_region(self):
ip, last_loc = self._get_insertion_point_and_loc()
then_block = self.builder.create_block()
self.builder.set_insertion_point_to_start(then_block)
then_val = language.core._to_tensor(self.visit(node.body), self.builder)
then_block = self.builder.get_insertion_block()
else_block = self.builder.create_block()
self.builder.set_insertion_point_to_start(else_block)
# do not need to reset lscope since
# ternary expressions cannot define new variables
else_val = language.core._to_tensor(self.visit(node.orelse), self.builder)
else_block = self.builder.get_insertion_block()
self._set_insertion_point_and_loc(ip, last_loc)
assert then_val.type == else_val.type, \
f'ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}'
ret_type = then_val.type
ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else []
if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True)
then_block.merge_block_before(if_op.get_then_block())
if ret_type_ir:
self.builder.set_insertion_point_to_end(if_op.get_then_block())
self.builder.create_yield_op([then_val.handle])
self.builder.set_insertion_point_to_end(if_op.get_then_block())
else_block.merge_block_before(if_op.get_else_block())
if ret_type_ir:
self.builder.set_insertion_point_to_end(if_op.get_else_block())
self.builder.create_yield_op([else_val.handle])
return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None
else:
cond = _unwrap_if_constexpr(cond)
if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks
# not isinstance - we insist the real thing, no subclasses and no ducks
if type(cond) not in _condition_types:
raise UnsupportedLanguageConstruct(
None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
', '.join(_.__name__ for _ in _condition_types), type(cond).__name__))
None, node,
"`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
', '.join(_.__name__ for _ in _condition_types),
type(cond).__name__))
if cond:
return self.visit(node.body)
else:
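The `visit_IfExp` rewrite above lowers a ternary with a tensor-valued condition to an `scf.if` whose branches each yield one value, instead of raising as before. A rough sketch of what this allows inside a jitted function; names are hypothetical, and it assumes the condition is a scalar tensor and that both branches produce the same type, as the assertion above requires:

import triton
import triton.language as tl

@triton.jit
def scaled_copy(X, Flag, Out, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    flag = tl.load(Flag)  # scalar value known only at run time
    # lowered to an scf.if with one yield per branch; both branches are float32
    scale = 2.0 if flag > 0 else 0.5
    tl.store(Out + offs, x * scale)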
@@ -654,8 +712,10 @@ class CodeGenerator(ast.NodeVisitor):
return constexpr(lhs_value is not rhs_value)
method_name = self._method_name_for_comp_op.get(type(node.ops[0]))
if method_name is None:
raise UnsupportedLanguageConstruct(None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
raise UnsupportedLanguageConstruct(
None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
return self._apply_binary_method(method_name, lhs, rhs)
_method_name_for_comp_op: Dict[Type[ast.cmpop], str] = {
ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__'
}
@@ -664,11 +724,15 @@ class CodeGenerator(ast.NodeVisitor):
op = self.visit(node.operand)
fn = self._method_name_for_unary_op.get(type(node.op))
if fn is None:
raise UnsupportedLanguageConstruct(None, node, "AST unary operator '{}' is not (currently) implemented.".format(node.op.__name__))
raise UnsupportedLanguageConstruct(
None, node, "AST unary operator '{}' is not (currently) implemented.".format(node.op.__name__))
if _is_triton_tensor(op):
return getattr(op, fn)(_builder=self.builder)
return getattr(op, fn)()
_method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'}
_method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {
ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'
}
def visit_While(self, node):
with enter_sub_region(self) as sr:
@@ -763,9 +827,7 @@ class CodeGenerator(ast.NodeVisitor):
iter_args = [self.visit(arg) for arg in node.iter.args]
if IteratorClass == language.static_range:
iterator = IteratorClass(*iter_args)
static_range = range(iterator.start.value,
iterator.end.value,
iterator.step.value)
static_range = range(iterator.start.value, iterator.end.value, iterator.step.value)
for i in static_range:
self.lscope[node.target.id] = constexpr(i)
self.visit_compound_statement(node.body)
@@ -902,8 +964,7 @@ class CodeGenerator(ast.NodeVisitor):
def call_JitFunction(self, fn: JITFunction, args, kwargs):
args = inspect.getcallargs(fn.fn, *args, **kwargs)
args = [args[name] for name in fn.arg_names]
args = [arg if _is_triton_tensor(arg)
else constexpr(arg) for arg in args]
args = [arg if _is_triton_tensor(arg) else constexpr(arg) for arg in args]
# generate function def
attributes = dict()
constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
@@ -921,8 +982,9 @@ class CodeGenerator(ast.NodeVisitor):
debug = self.debug if fn.debug is None else fn.debug
file_name, begin_line = _get_fn_file_line(fn)
generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module,
function_name=fn_name, function_types=self.function_ret_types, debug=debug, noinline=fn.noinline,
file_name=file_name, begin_line=begin_line, target=self.builder.target)
function_name=fn_name, function_types=self.function_ret_types, debug=debug,
noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
target=self.builder.target)
generator.visit(fn.parse())
callee_ret_type = generator.last_ret_type
self.function_ret_types[fn_name] = callee_ret_type
@@ -950,7 +1012,7 @@ class CodeGenerator(ast.NodeVisitor):
kws = dict(self.visit(keyword) for keyword in node.keywords)
args = [self.visit(arg) for arg in node.args]
if fn is language.core.device_assert: # TODO: this should not be so hardcoded
if fn is language.core.device_assert: # TODO: this should not be so hardcoded
if not self.debug:
return
if isinstance(fn, JITFunction):
@@ -971,16 +1033,21 @@ class CodeGenerator(ast.NodeVisitor):
def visit_BoolOp(self, node: ast.BoolOp):
if len(node.values) != 2:
raise UnsupportedLanguageConstruct(None, node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
raise UnsupportedLanguageConstruct(
None, node,
"chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
lhs = self.visit(node.values[0])
rhs = self.visit(node.values[1])
method_name = self._method_name_for_bool_op.get(type(node.op))
if method_name is None:
raise UnsupportedLanguageConstruct(None, node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
raise UnsupportedLanguageConstruct(
None, node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
return self._apply_binary_method(method_name, lhs, rhs)
_method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}
if sys.version_info < (3, 8):
def visit_NameConstant(self, node):
return constexpr(node.value)
@@ -1013,7 +1080,9 @@ class CodeGenerator(ast.NodeVisitor):
evaluated = self.visit(value.value)
if not _is_constexpr(evaluated):
raise UnsupportedLanguageConstruct(
None, node, "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + str(type(evaluated)))
None, node,
"Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type "
+ str(type(evaluated)))
values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value)
else:
raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value)))
@@ -1055,7 +1124,9 @@ class CodeGenerator(ast.NodeVisitor):
passed = _unwrap_if_constexpr(self.visit(node.args[0]))
if not isinstance(passed, bool):
raise NotImplementedError("Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values")
raise NotImplementedError(
"Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values"
)
if not passed:
if arg_count == 1:
message = ""
@@ -1144,10 +1215,9 @@ def ast_to_ttir(fn, signature, specialization, constants, debug, target):
file_name, begin_line = _get_fn_file_line(fn)
prototype = language.function_type([], arg_types)
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants,
function_name=function_name, attributes=new_attrs,
is_kernel=True, debug=debug, file_name=file_name, begin_line=begin_line,
target=target)
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name,
attributes=new_attrs, is_kernel=True, debug=debug, file_name=file_name,
begin_line=begin_line, target=target)
try:
generator.visit(fn.parse())
except CompilationError as e:

View File

@@ -11,25 +11,21 @@ from typing import Any
from dataclasses import dataclass
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs,
compile_ptx_to_cubin, get_env_vars, get_num_warps,
get_shared_memory_size, ir, runtime,
translate_llvmir_to_ptx,
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs, compile_ptx_to_cubin, get_env_vars,
get_num_warps, get_shared_memory_size, ir, runtime, translate_llvmir_to_ptx,
translate_triton_gpu_to_llvmir)
from ..common.backend import get_backend, path_to_ptxas
from ..common.backend import get_backend, get_cuda_version_key, path_to_ptxas
from ..common.build import is_hip
# from ..runtime import driver, jit, JITFunction
# TODO: runtime.errors
from ..runtime.autotuner import OutOfResources
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
from ..runtime.driver import driver
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device,
get_device_capability, version_key)
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device, get_device_capability)
from ..tools.disasm import get_sass
from .code_generator import ast_to_ttir
from .make_launcher import make_stub
from .utils import (InfoFromBackendForTensorMap, TensorMapManager,
get_ids_of_tensormaps, parse_tma_info)
from .utils import (InfoFromBackendForTensorMap, TensorMapManager, get_ids_of_tensormaps, parse_tma_info)
CUDA_DEFAULT_WARP_SIZE = 32
@@ -45,6 +41,7 @@ def _is_cuda(target):
class LazyDict(dict):
def __getitem__(self, key):
val = dict.__getitem__(self, key)
if callable(val):
@@ -103,8 +100,13 @@ def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, target):
return mod
<<<<<<< HEAD
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target,
cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_inst_type):
=======
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization,
enable_persistent, optimize_epilogue):
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
is_cuda = _is_cuda(target)
if is_cuda:
capability = target.capability
@@ -128,9 +130,13 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target,
if optimize_epilogue:
pm.add_tritongpu_optimize_epilogue_pass()
pm.add_tritongpu_optimize_dot_operands_pass()
<<<<<<< HEAD
if num_stages == 0 and is_hip() and gpu_matrix_core_version() != 0:
pm.add_tritongpu_stream_pipeline_pass()
pm.add_canonicalizer_pass()
=======
pm.add_cse_pass()
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
ws_enabled = False
# `num_warps` does not mean the total number of warps of a CTA when
# warp specialization is enabled.
@@ -174,6 +180,8 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target,
if is_cuda and capability // 10 >= 9:
pm.add_tritongpu_fence_insertion_pass()
pm.add_tritongpu_ws_fixup_missing_attrs_pass()
pm.add_tritongpu_optimize_thread_locality_pass()
pm.add_canonicalizer_pass()
pm.run(mod)
return mod
@@ -197,6 +205,7 @@ def ttgir_to_llir(mod, extern_libs, target, tma_infos, waves_per_eu=0):
# PTX translation
@functools.lru_cache()
def ptx_get_version(cuda_version) -> int:
'''
@@ -261,7 +270,11 @@ def convert_type_repr(x):
return x
def make_hash(fn, target, env_vars, **kwargs):
def make_hash(fn, target, env_vars, device_backend, **kwargs):
if device_backend is None:
version_key = get_cuda_version_key()
else:
version_key = device_backend.get_version_key()
if isinstance(fn, JITFunction):
configs = kwargs["configs"]
signature = kwargs["signature"]
@@ -275,16 +288,21 @@ def make_hash(fn, target, env_vars, **kwargs):
enable_persistent = kwargs.get("enable_persistent", False)
debug = kwargs.get("debug", False)
# Get unique key for the compiled code
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1),
sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
configs_key = [get_conf_key(conf) for conf in configs]
env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())]
<<<<<<< HEAD
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
=======
key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
return hashlib.md5(key.encode("utf-8")).hexdigest()
assert isinstance(fn, str)
ignore_version = kwargs.get('ignore_version', False)
if (ignore_version):
return hashlib.md5((Path(fn).read_text()).encode("utf-8")).hexdigest()
return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest()
return hashlib.md5((Path(fn).read_text() + version_key).encode("utf-8")).hexdigest()
# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
@@ -321,12 +339,14 @@ else:
def _get_jsonable_constants(constants):
def _is_jsonable(x):
try:
json.dumps(x)
return True
except (TypeError, OverflowError):
return False
serialized_constants = {}
for constant in constants:
if _is_jsonable(constants[constant]):
@@ -341,7 +361,9 @@ def parse_mlir_module(path, context):
return module
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], defaults=[set(), set(), set(), set()])
instance_descriptor = namedtuple("instance_descriptor",
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
defaults=[set(), set(), set(), set()])
def is_hip():
@@ -385,10 +407,16 @@ def get_arch_default_num_stages(device_type, capability=None):
def add_cuda_stages(target, extern_libs, stages):
<<<<<<< HEAD
stages["ptx"] = (lambda path: Path(path).read_text(),
lambda src: llir_to_ptx(src, target))
stages["cubin"] = (lambda path: Path(path).read_bytes(),
lambda src: ptx_to_cubin(src, target))
=======
stages["ptx"] = (lambda path: Path(path).read_text(), lambda src: llir_to_ptx(src, target))
stages["cubin"] = (lambda path: Path(path).read_bytes(), lambda src: ptx_to_cubin(src, target))
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
def compile(fn, **kwargs):
@@ -434,7 +462,8 @@ def compile(fn, **kwargs):
# build architecture descriptor
if device_type == "cuda":
_device_backend = get_backend(device_type)
target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps, enable_fp_fusion=enable_fp_fusion)
target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps,
enable_fp_fusion=enable_fp_fusion)
else:
_device_backend = get_backend(device_type)
assert _device_backend
@@ -443,11 +472,12 @@ def compile(fn, **kwargs):
# build compilation stages
stages = dict()
stages["ast"] = (lambda path: fn, None)
stages["ttir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
stages["ttir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttir(
ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
if is_cuda:
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
stages["ttgir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttgir(
ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info,
enable_warp_specialization, enable_persistent, optimize_epilogue))
stages["llir"] = (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, target, tma_infos))
add_cuda_stages(target, extern_libs, stages)
@@ -507,18 +537,21 @@ def compile(fn, **kwargs):
if ir_name == 'ttgir':
num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
assert "num_warps" not in kwargs or int(num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile"
assert "num_warps" not in kwargs or int(
num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile"
num_warps = int(num_warps_matches[0])
param_tys = [convert_type_repr(ty) for ty in types]
signature = {k: v for k, v in enumerate(param_tys)}
first_stage = list(stages.keys()).index(ir_name)
# create cache manager
fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), **kwargs))
fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs))
# managers used to dump and override IR for debugging
enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
fn_override_manager = get_override_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True))
fn_dump_manager = get_dump_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True))
fn_override_manager = get_override_manager(
make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
fn_dump_manager = get_dump_manager(
make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
# determine name and extension type of provided function
if isinstance(fn, JITFunction):
@@ -531,9 +564,7 @@ def compile(fn, **kwargs):
metadata_filename = f"{name}.json"
# The group is addressed by the metadata
metadata_group = fn_cache_manager.get_group(
metadata_filename
) or {}
metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
metadata_path = metadata_group.get(metadata_filename)
@@ -541,9 +572,9 @@ def compile(fn, **kwargs):
with open(metadata_path) as f:
metadata = json.load(f)
if 'tensormaps_info' in metadata:
metadata['tensormaps_info'] = [
InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']]
metadata['tensormaps_info'] = [InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']]
else:
<<<<<<< HEAD
metadata = {"num_warps": num_warps,
"warp_size": warp_size,
"num_ctas": num_ctas,
@@ -555,6 +586,18 @@ def compile(fn, **kwargs):
"constants": _get_jsonable_constants(constants),
"debug": debug,
"target": target, }
=======
metadata = {
"num_warps": num_warps,
"num_ctas": num_ctas,
"num_stages": num_stages,
"enable_warp_specialization": enable_warp_specialization,
"enable_persistent": enable_persistent,
"constants": _get_jsonable_constants(constants),
"debug": debug,
"target": target,
}
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
metadata.update(get_env_vars())
if ext == "ptx":
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
@@ -626,10 +669,7 @@ def compile(fn, **kwargs):
ids_of_folded_args = tuple([int(k) for k in configs[0].ids_of_folded_args]) if isinstance(fn, JITFunction) else ()
if "clusterDims" not in metadata:
metadata["clusterDims"] = [
cluster_info.clusterDimX,
cluster_info.clusterDimY,
cluster_info.clusterDimZ]
metadata["clusterDims"] = [cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ]
if len(tma_infos) > 0:
metadata["tensormaps_info"] = parse_tma_info(tma_infos, ids_of_folded_args)
@@ -643,7 +683,10 @@ def compile(fn, **kwargs):
fn.tensormaps_info = metadata["tensormaps_info"]
ids_of_const_exprs = tuple(fn.constexprs) if isinstance(fn, JITFunction) else ()
ids = {"ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs": ids_of_const_exprs}
ids = {
"ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs":
ids_of_const_exprs
}
# cache manager
if is_cuda:
so_path = make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization)
@@ -651,7 +694,8 @@ def compile(fn, **kwargs):
so_path = _device_backend.make_launcher_stub(name, signature, constants, ids)
# write-back metadata, if it didn't come from the cache
if metadata_path is None:
metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, binary=False)
metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
binary=False)
fn_cache_manager.put_group(metadata_filename, metadata_group)
# return handle to compiled kernel
@@ -701,10 +745,7 @@ class CompiledKernel:
if self.device_type in ["cuda"]:
device = get_current_device()
bin_path = {
driver.HIP: "hsaco_path",
driver.CUDA: "cubin"
}[driver.backend]
bin_path = {driver.HIP: "hsaco_path", driver.CUDA: "cubin"}[driver.backend]
max_shared = driver.utils.get_device_properties(device)["max_shared_mem"]
fn_load_binary = driver.utils.load_binary
else:
@@ -752,4 +793,5 @@ class CompiledKernel:
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.clusterDims[0],
self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function,
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand)
return runner

View File

@@ -3,9 +3,9 @@ import os
import tempfile
from ..common import _build
from ..common.backend import get_cuda_version_key
from ..common.build import is_hip
from ..runtime.cache import get_cache_manager
from ..runtime.jit import version_key
from .utils import generate_cu_signature
# ----- stub --------
@@ -23,7 +23,7 @@ def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
def make_stub(name, signature, constants, ids, **kwargs):
# name of files that are cached
so_cache_key = make_so_cache_key(version_key(), signature, constants, ids, **kwargs)
so_cache_key = make_so_cache_key(get_cuda_version_key(), signature, constants, ids, **kwargs)
so_cache_manager = get_cache_manager(so_cache_key)
so_name = f"{name}.so"
# retrieve stub from cache if it exists
@@ -40,6 +40,7 @@ def make_stub(name, signature, constants, ids, **kwargs):
else:
return cache_path
# ----- source code generation --------
@@ -100,7 +101,10 @@ def generate_launcher(constants, signature, ids):
# generate glue code
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
params = [i for i in signature.keys() if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)]
params = [
i for i in signature.keys()
if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)
]
src = f"""
#include \"cuda.h\"
#include <stdbool.h>

View File

@@ -158,19 +158,21 @@ class InfoFromBackendForTensorMap:
# dtype:cuda.CUtensorMapDataType | int
def bytes_from_type(self, dtype):
return {driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT8"]: 1,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT16"]: 2,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT64"]: 8,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT64"]: 8,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]: 2,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT64"]: 8,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"]: 2,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ"]: 4}[dtype]
return {
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT8"]: 1,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT16"]: 2,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT64"]: 8,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT64"]: 8,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]: 2,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT64"]: 8,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"]: 2,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ"]: 4
}[dtype]
def getTensorMapDataType(self):
return self.tensorDataType
@@ -259,22 +261,29 @@ class InfoFromBackendForTensorMap:
self.getInterleave(),
self.getSwizzle(),
self.getL2Promotion(),
self.getOobFill()
self.getOobFill(),
)
# make hashable to use as partial key in cache
def __hash__(self):
return hash((self.ids_of_folded_args, self.globalAddressArgIdx, tuple(self.globalDimsArgIdx), tuple(self.globalStridesArgIdx), self.tensorDataType,
self.tensorRank, tuple(self.boxDims), tuple(self.elementStrides), self.interleave, self.swizzle, self.l2Promotion, self.oobFill))
return hash((self.ids_of_folded_args, self.globalAddressArgIdx, tuple(self.globalDimsArgIdx),
tuple(self.globalStridesArgIdx), self.tensorDataType, self.tensorRank, tuple(self.boxDims),
tuple(self.elementStrides), self.interleave, self.swizzle, self.l2Promotion, self.oobFill))
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
return (self.ids_of_folded_args, self.globalAddressArgIdx, self.globalDimsArgIdx, self.globalStridesArgIdx, self.tensorDataType, self.tensorRank, self.boxDims, self.elementStrides, self.interleave, self.swizzle, self.l2Promotion, self.oobFill) == (
other.ids_of_folded_args, other.globalAddressArgIdx, other.globalDimsArgIdx, other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims, other.elementStrides, other.interleave, other.swizzle, other.l2Promotion, other.oobFill)
return (self.ids_of_folded_args, self.globalAddressArgIdx, self.globalDimsArgIdx, self.globalStridesArgIdx,
self.tensorDataType, self.tensorRank, self.boxDims, self.elementStrides, self.interleave, self.swizzle,
self.l2Promotion,
self.oobFill) == (other.ids_of_folded_args, other.globalAddressArgIdx, other.globalDimsArgIdx,
other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims,
other.elementStrides, other.interleave, other.swizzle, other.l2Promotion,
other.oobFill)
class TensorMapManager:
def __init__(self):
self.tensormaps_device = {}
@@ -286,8 +295,7 @@ class TensorMapManager:
t_tensormap = e.tensormap(args)
TENSORMAP_SIZE_IN_BYTES = 128
t_tensormap_device = driver.utils.cuMemAlloc(TENSORMAP_SIZE_IN_BYTES)
driver.utils.cuMemcpyHtoD(
t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES)
driver.utils.cuMemcpyHtoD(t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES)
self.tensormaps_device[key] = t_tensormap_device
return int(self.tensormaps_device[key])

View File

@@ -111,7 +111,6 @@ from .random import (
uint32_to_uniform_float,
)
__all__ = [
"TRITON_MAX_TENSOR_NUMEL",
"abs",

View File

@@ -22,10 +22,8 @@ def builtin(fn: T) -> T:
@wraps(fn)
def wrapper(*args, **kwargs):
if "_builder" not in kwargs or kwargs["_builder"] is None:
raise ValueError(
"Did you forget to add @triton.jit ? "
"(`_builder` argument must be provided outside of JIT functions.)"
)
raise ValueError("Did you forget to add @triton.jit ? "
"(`_builder` argument must be provided outside of JIT functions.)")
return fn(*args, **kwargs)
setattr(wrapper, TRITON_BUILTIN, True)
@@ -54,7 +52,7 @@ def _to_tensor(x, builder):
else:
raise RuntimeError(f'Nonrepresentable integer {x}.')
elif isinstance(x, float):
min_float32 = 2 ** -126
min_float32 = 2**-126
max_float32 = (2 - 2**-23) * 2**127
abs_x = __builtins__['abs'](x)
if abs_x == float("inf") or\
@@ -243,7 +241,7 @@ class dtype:
return not self.__eq__(other)
def __hash__(self):
return hash((self.name,))
return hash((self.name, ))
@property
def scalar(self):
@@ -297,6 +295,7 @@ class dtype:
class pointer_type(dtype):
def __init__(self, element_ty: dtype, address_space: int = 1):
if not isinstance(element_ty, dtype):
raise TypeError(f'element_ty is a {type(element_ty).__name__}.')
@@ -331,6 +330,7 @@ class pointer_type(dtype):
class block_type(dtype):
def __init__(self, element_ty: dtype, shape: List):
self.element_ty = element_ty
@@ -381,6 +381,7 @@ class block_type(dtype):
class function_type(dtype):
def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None:
self.ret_types = ret_types
self.param_types = param_types
@@ -531,7 +532,7 @@ class constexpr:
return constexpr(~self.value)
def __pow__(self, other):
return constexpr(self.value ** other.value)
return constexpr(self.value**other.value)
def __rshift__(self, other):
return constexpr(self.value >> other.value)
@@ -547,6 +548,7 @@ class constexpr:
class tensor:
def __init__(self, handle, type: dtype):
# IR handle
self.handle = handle
@@ -740,11 +742,21 @@ class tensor:
other = _to_tensor(other, _builder)
return semantic.equal(self, other, _builder)
@builtin
def __req__(self, other, _builder=None):
other = _to_tensor(other, _builder)
return semantic.equal(other, self, _builder)
@builtin
def __ne__(self, other, _builder=None):
other = _to_tensor(other, _builder)
return semantic.not_equal(self, other, _builder)
@builtin
def __rne__(self, other, _builder=None):
other = _to_tensor(other, _builder)
return semantic.not_equal(other, self, _builder)
@builtin
def logical_and(self, other, _builder=None):
other = _to_tensor(other, _builder)
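The hunk above adds reflected `__req__`/`__rne__`, so an equality test also dispatches to the tensor when it sits on the right-hand side of the operator inside a jitted function. A small sketch (hypothetical kernel; X and Out are assumed to be int32 pointers):

import triton
import triton.language as tl

@triton.jit
def count_zeros(X, Out, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    # a constant on the left used to fail; the code generator now falls back
    # to tensor.__req__ for the comparison
    is_zero = 0 == x
    tl.store(Out, tl.sum(is_zero.to(tl.int32), axis=0))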
@@ -1023,6 +1035,7 @@ def expand_dims(input, axis, _builder=None):
ret = semantic.expand_dims(ret, a, _builder)
return ret
# -----------------------
# Linear Algebra
# -----------------------
@@ -1171,6 +1184,7 @@ def advance(base: tensor, offsets, _builder=None):
"""
return semantic.advance(base, offsets, _builder)
# -----------------------
# Atomic Memory Operations
# -----------------------
@@ -1196,6 +1210,9 @@ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
:param sem: Memory semantics to use ("ACQUIRE_RELEASE" (default),
"ACQUIRE", "RELEASE", or "RELAXED")
:type sem: str
:param scope: Scope of threads that observe synchronizing effect of the
atomic operation ("GPU" (default), "CTA", or "SYSTEM")
:type scope: str
"""
func.__doc__ = docstr
return func
@@ -1205,73 +1222,82 @@ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
@builtin
@_add_atomic_docstr("compare-and-swap", has_cmp=True)
def atomic_cas(pointer, cmp, val, sem=None, _builder=None):
def atomic_cas(pointer, cmp, val, sem=None, scope=None, _builder=None):
cmp = _to_tensor(cmp, _builder)
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_cas(pointer, cmp, val, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_cas(pointer, cmp, val, sem, scope, _builder)
@builtin
@_add_atomic_docstr("exchange")
def atomic_xchg(pointer, val, mask=None, sem=None, _builder=None):
def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _builder=None):
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_xchg(pointer, val, mask, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_xchg(pointer, val, mask, sem, scope, _builder)
@builtin
@_add_atomic_docstr("add")
def atomic_add(pointer, val, mask=None, sem=None, _builder=None):
def atomic_add(pointer, val, mask=None, sem=None, scope=None, _builder=None):
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_add(pointer, val, mask, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_add(pointer, val, mask, sem, scope, _builder)
@builtin
@_add_atomic_docstr("max")
def atomic_max(pointer, val, mask=None, sem=None, _builder=None):
def atomic_max(pointer, val, mask=None, sem=None, scope=None, _builder=None):
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_max(pointer, val, mask, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_max(pointer, val, mask, sem, scope, _builder)
@builtin
@_add_atomic_docstr("min")
def atomic_min(pointer, val, mask=None, sem=None, _builder=None):
def atomic_min(pointer, val, mask=None, sem=None, scope=None, _builder=None):
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_min(pointer, val, mask, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_min(pointer, val, mask, sem, scope, _builder)
@builtin
@_add_atomic_docstr("logical and")
def atomic_and(pointer, val, mask=None, sem=None, _builder=None):
def atomic_and(pointer, val, mask=None, sem=None, scope=None, _builder=None):
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_and(pointer, val, mask, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_and(pointer, val, mask, sem, scope, _builder)
@builtin
@_add_atomic_docstr("logical or")
def atomic_or(pointer, val, mask=None, sem=None, _builder=None):
def atomic_or(pointer, val, mask=None, sem=None, scope=None, _builder=None):
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_or(pointer, val, mask, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_or(pointer, val, mask, sem, scope, _builder)
@builtin
@_add_atomic_docstr("logical xor")
def atomic_xor(pointer, val, mask=None, sem=None, _builder=None):
def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None):
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_xor(pointer, val, mask, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_xor(pointer, val, mask, sem, scope, _builder)
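Every atomic builtin above gains a `scope` keyword alongside `sem`. A hedged usage sketch: the kernel is hypothetical, `Counter` is assumed to be an int32 pointer, and the lower-case "relaxed"/"gpu" spellings are an assumption based on the values the docstring above lists ("ACQUIRE_RELEASE", "GPU", ...).

import triton
import triton.language as tl

@triton.jit
def count_nonzero(X, Counter, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(X + offs, mask=mask, other=0)
    n = tl.sum((x != 0).to(tl.int32), axis=0)
    # relaxed ordering, visible GPU-wide; scope="cta" would restrict
    # visibility to the thread block (string casing is an assumption)
    tl.atomic_add(Counter, n, sem="relaxed", scope="gpu")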
# -----------------------
# Conditioning
# -----------------------
@builtin
def where(condition, x, y, _builder=None):
"""
@@ -1299,6 +1325,7 @@ def where(condition, x, y, _builder=None):
# Math
# -----------------------
@builtin
def umulhi(x, y, _builder=None):
"""
@@ -1392,6 +1419,7 @@ def abs(x, _builder=None):
# Reductions
# -----------------------
def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]:
def _decorator(func: T) -> T:
@@ -1430,8 +1458,7 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None):
"""
if isinstance(input, tensor):
return reduce((input,), axis, combine_fn,
_builder=_builder, _generator=_generator)[0]
return reduce((input, ), axis, combine_fn, _builder=_builder, _generator=_generator)[0]
def make_combine_region(reduce_op):
in_scalar_tys = [t.type.scalar for t in input]
@@ -1441,14 +1468,14 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None):
with _insertion_guard(_builder):
param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
block = _builder.create_block_with_parent(region, param_types)
args = [tensor(block.arg(i), ty)
for i, ty in enumerate(prototype.param_types)]
args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
results = _generator.call_JitFunction(combine_fn, args, kwargs={})
if isinstance(results, tensor):
handles = [results.handle]
else:
handles = [r.handle for r in results]
_builder.create_reduce_ret(*handles)
if axis is not None:
axis = _constexpr_to_value(axis)
return semantic.reduction(input, axis, make_combine_region, _builder)
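`tl.reduce` above takes a `@triton.jit` combine function and inlines it into the reduction region it builds. A minimal sketch of the calling side (hypothetical kernel):

import triton
import triton.language as tl

@triton.jit
def _max_combine(a, b):
    return tl.maximum(a, b)

@triton.jit
def absmax_kernel(X, Out, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    # combine_fn must itself be @triton.jit; it becomes the body of the
    # reduce region created above
    result = tl.reduce(tl.abs(x), 0, _max_combine)
    tl.store(Out, result)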
@@ -1483,8 +1510,7 @@ def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None
index = expand_dims(index, axes_to_expand, _builder=_builder)
index = broadcast_to(index, input.shape, _builder=_builder)
rvalue, rindices = reduce((input, index), axis, combine_fn,
_builder=_builder, _generator=_generator)
rvalue, rindices = reduce((input, index), axis, combine_fn, _builder=_builder, _generator=_generator)
return rvalue, rindices
@@ -1492,6 +1518,7 @@ def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None
# Scans
# -----------------------
def _add_scan_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]:
def _decorator(func: T) -> T:
@@ -1516,8 +1543,7 @@ def associative_scan(input, axis, combine_fn, _builder=None, _generator=None):
"""
if isinstance(input, tensor):
return associative_scan((input,), axis, combine_fn,
_builder=_builder, _generator=_generator)[0]
return associative_scan((input, ), axis, combine_fn, _builder=_builder, _generator=_generator)[0]
def make_combine_region(scan_op):
in_scalar_tys = [t.type.scalar for t in input]
@@ -1527,17 +1553,18 @@ def associative_scan(input, axis, combine_fn, _builder=None, _generator=None):
with _insertion_guard(_builder):
param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
block = _builder.create_block_with_parent(region, param_types)
args = [tensor(block.arg(i), ty)
for i, ty in enumerate(prototype.param_types)]
args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
results = _generator.call_JitFunction(combine_fn, args, kwargs={})
if isinstance(results, tensor):
handles = [results.handle]
else:
handles = [r.handle for r in results]
_builder.create_scan_ret(*handles)
axis = _constexpr_to_value(axis)
return semantic.associative_scan(input, axis, make_combine_region, _builder)
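`tl.associative_scan` mirrors `tl.reduce` but keeps the running value at every position; the standard-library `cumsum` shown later in this diff is implemented directly on top of it. A small sketch (hypothetical kernel):

import triton
import triton.language as tl

@triton.jit
def _add_combine(a, b):
    return a + b

@triton.jit
def prefix_sum_kernel(X, Out, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    # inclusive prefix sum along axis 0, equivalent to tl.cumsum(x, 0)
    y = tl.associative_scan(x, 0, _add_combine)
    tl.store(Out + offs, y)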
# -----------------------
# Compiler Hint Ops
# -----------------------
@@ -1600,6 +1627,8 @@ def max_constancy(input, values, _builder=None):
raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
values = [x.value for x in values]
return semantic.max_constancy(input, values)
# -----------------------
# Debugging functions
# -----------------------
@@ -1739,12 +1768,12 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur
broadcast_arg = dispatch_args[0]
# Get the broadcast shape over all the arguments
for i, item in enumerate(dispatch_args):
_, broadcast_arg = semantic.binary_op_type_checking_impl(
item, broadcast_arg, _builder, arithmetic_check=False)
_, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder,
arithmetic_check=False)
# Change the shape of each argument based on the broadcast shape
for i in range(len(dispatch_args)):
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
dispatch_args[i], broadcast_arg, _builder, arithmetic_check=False)
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder,
arithmetic_check=False)
ret_shape = broadcast_arg.shape
res_ty = block_type(dtype, ret_shape)
call = _builder.create_inline_asm(asm, constraints, [t.handle for t in args], res_ty.to_ir(_builder), is_pure, pack)
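`inline_asm_elementwise` broadcasts its tensor arguments to a common shape and emits the asm once per packed group of elements; the `globaltimer`/`smid` helpers in the next file are thin wrappers over it. A sketch of a direct call reusing the same PTX as `smid` (the kernel itself is hypothetical):

import triton
import triton.language as tl

@triton.jit
def smid_kernel(Out):
    pid = tl.program_id(0)
    # no tensor inputs, so this returns a scalar int32 holding the SM id
    sm = tl.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=tl.int32,
                                   is_pure=True, pack=1)
    tl.store(Out + pid, sm)  # record which SM ran this program instance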
@@ -1757,7 +1786,6 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur
class static_range:
"""
Iterator whose bounds are compile-time constants; the loop over it is unrolled during compilation.
@@ -1801,7 +1829,9 @@ class static_range:
# Extern functions
# -----------------------
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, is_pure: bool, _builder=None):
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple,
is_pure: bool, _builder=None):
'''
Dispatch a function to a library
:param func: the function to dispatch
@@ -1843,7 +1873,8 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic
return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type)
def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, _builder=None):
def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool,
_builder=None):
'''
Dispatch an elementwise function to a library
:param lib_name: the name of the library
@@ -1872,12 +1903,12 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
broadcast_arg = dispatch_args[0]
# Get the broadcast shape over all the arguments
for i, item in enumerate(dispatch_args):
_, broadcast_arg = semantic.binary_op_type_checking_impl(
item, broadcast_arg, _builder, arithmetic_check=arithmetic_check)
_, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder,
arithmetic_check=arithmetic_check)
# Change the shape of each argument based on the broadcast shape
for i in range(len(dispatch_args)):
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
dispatch_args[i], broadcast_arg, _builder, arithmetic_check=arithmetic_check)
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder,
arithmetic_check=arithmetic_check)
if not all_scalar:
ret_shape = broadcast_arg.shape
func = getattr(_builder, "create_extern_elementwise")

View File

@@ -3,16 +3,14 @@ from .. import core
@core.extern
def globaltimer(_builder=None):
return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [],
dtype=core.int64, is_pure=False,
pack=1, _builder=_builder)
return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], dtype=core.int64, is_pure=False, pack=1,
_builder=_builder)
@core.extern
def smid(_builder=None):
return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [],
dtype=core.int32, is_pure=True,
pack=1, _builder=_builder)
return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1,
_builder=_builder)
@core.builtin

File diff suppressed because it is too large

View File

@@ -91,6 +91,7 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
# two_to_the_minus_32: tl.constexpr = 2.328306e-10
# return x * two_to_the_minus_32
@jit
def uint32_to_uniform_float(x):
"""
@@ -134,6 +135,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
u4 = uint32_to_uniform_float(i4)
return u1, u2, u3, u4
# -------------------
# randn
# -------------------

File diff suppressed because it is too large

View File

@@ -123,6 +123,7 @@ def maximum(x, y):
"""
return math.max(x, y)
# max and argmax
@@ -149,8 +150,7 @@ def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
@jit
@core._add_reduction_docstr("maximum",
return_indices_arg="return_indices",
@core._add_reduction_docstr("maximum", return_indices_arg="return_indices",
tie_break_arg="return_indices_tie_break_left")
def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
input = core._promote_reduction_input(input)
@@ -175,6 +175,7 @@ def argmax(input, axis, tie_break_left=True):
(_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
return ret
# min and argmin
@@ -201,8 +202,7 @@ def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
@jit
@core._add_reduction_docstr("minimum",
return_indices_arg="return_indices",
@core._add_reduction_docstr("minimum", return_indices_arg="return_indices",
tie_break_arg="return_indices_tie_break_left")
def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
input = core._promote_reduction_input(input)
@@ -222,8 +222,7 @@ def min(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
@jit
@core._add_reduction_docstr("minimum index",
tie_break_arg="tie_break_left")
@core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left")
def argmin(input, axis, tie_break_left=True):
_, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
return ret
@@ -233,6 +232,7 @@ def argmin(input, axis, tie_break_left=True):
def _sum_combine(a, b):
return a + b
# sum
@@ -247,6 +247,7 @@ def sum(input, axis=None):
def _xor_combine(a, b):
return a ^ b
# xor sum
@@ -258,8 +259,8 @@ def xor_sum(input, axis=None, _builder=None, _generator=None):
raise ValueError("xor_sum only supported for integers")
input = core._promote_reduction_input(input, _builder=_builder)
return core.reduce(input, axis, _xor_combine,
_builder=_builder, _generator=_generator)
return core.reduce(input, axis, _xor_combine, _builder=_builder, _generator=_generator)
# cumsum
@@ -271,6 +272,7 @@ def cumsum(input, axis=0):
input = core._promote_reduction_input(input)
return core.associative_scan(input, axis, _sum_combine)
# cumprod
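The reformatted decorators above document the `return_indices` and tie-break arguments on `max` and `min`; `argmax`/`argmin` simply return the index half of that pair. A short sketch (hypothetical kernel):

import triton
import triton.language as tl

@triton.jit
def rowmax_kernel(X, Val, Idx, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    # a single pass returns both the maximum and its (left-most) index
    v, i = tl.max(x, axis=0, return_indices=True)
    tl.store(Val, v)
    tl.store(Idx, i)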

View File

@@ -17,15 +17,14 @@ from ... import language as tl
'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
})
@jit
def _sdd_kernel(
A, B, C,
stride_za, stride_ha, stride_ma, stride_ak,
stride_zb, stride_hb, stride_bk, stride_nb,
stride_zc, stride_hc, stride_mc, stride_nc,
K, grid_offset, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
BLOCK: tl.constexpr, EVEN_K: tl.constexpr
):
def _sdd_kernel(A, B, C, #
stride_za, stride_ha, stride_ma, stride_ak, #
stride_zb, stride_hb, stride_bk, stride_nb, #
stride_zc, stride_hc, stride_mc, stride_nc, #
K, grid_offset, lut, #
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, #
BLOCK: tl.constexpr, EVEN_K: tl.constexpr #
):
# ------------ #
# - Prologue - #
# ------------ #
@@ -104,13 +103,13 @@ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=
c = out
grid = [c.shape[1], 1, c.shape[0]]
_sdd_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
Ka, 0, lut,
TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4,
num_warps=4,
a, b, c, #
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), #
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), #
c.stride(0), c.stride(1), c.stride(2), c.stride(3), #
Ka, 0, lut, #
TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4, #
num_warps=4 #
)
return c
@@ -120,6 +119,7 @@ def sdd_lut(layout, block, device):
lut = lut.contiguous()
return lut, None
# -----------------------------
# Dense = Sparse x Dense (DSD)
# This operation uses a look-up table that contains pre-computed pointer increments
@@ -128,15 +128,14 @@ def sdd_lut(layout, block, device):
@jit
def _dsd_kernel(
A, B, C,
stride_az, stride_ha, stride_am, stride_ak,
stride_zb, stride_hb, stride_bk, stride_bn,
stride_zc, stride_hc, stride_cm, stride_cn,
DS0, DS1, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
):
def _dsd_kernel(A, B, C, #
stride_az, stride_ha, stride_am, stride_ak, #
stride_zb, stride_hb, stride_bk, stride_bn, #
stride_zc, stride_hc, stride_cm, stride_cn, #
DS0, DS1, lut, #
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr #
):
# ------------ #
# - Prologue - #
# ------------ #
@@ -229,13 +228,13 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N
# compute output
grid = lambda meta: [cdiv(BS3, meta['TILE_N']), width, BS0]
_dsd_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
BS3, AS1, lut,
TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
num_warps=4, GROUP_SIZE_M=4,
a, b, c, #
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), #
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), #
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), #
BS3, AS1, lut, #
TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4, #
num_warps=4, GROUP_SIZE_M=4 #
)
# exit()
return c
@@ -337,6 +336,7 @@ def dsd_lut(layout, block, step, trans, device):
# create locks
return lut, width
# -----------------------------
# Dense = Dense x Sparse (DDS)
# -----------------------------
@@ -346,6 +346,7 @@ def dsd_lut(layout, block, step, trans, device):
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out)
##############
# MAIN API #
##############
@@ -356,10 +357,8 @@ class _matmul(torch.autograd.Function):
fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul}
@staticmethod
def forward(
ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
c_lut, c_width, da_lut, da_width, db_lut, db_width, out
):
def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_width, da_lut, da_width, db_lut,
db_width, out):
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out)
# save for backward
ctx.save_for_backward(a, b)
@@ -385,15 +384,13 @@ class _matmul(torch.autograd.Function):
# gradients w.r.t. a
if ctx.needs_input_grad[0]:
mode_da = mode[1] + mode[0] + mode[2]
da = _matmul.fn[mode_da](
dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width,
)
da = _matmul.fn[mode_da](dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block,
ctx.da_lut, ctx.da_width)
# gradients w.r.t. b
if ctx.needs_input_grad[1]:
mode_db = mode[2] + mode[1] + mode[0]
db = _matmul.fn[mode_db](
a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width,
)
db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block,
ctx.db_lut, ctx.db_width)
dout = dc if ctx.has_out else None
return da, db, None, None, None, \
None, None, None, None, \
@@ -427,11 +424,9 @@ class matmul:
self.db_lut, self.db_width = sdd_lut(layout, block, device)
def __call__(self, a, b, out=None):
c = _matmul.apply(
a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
self.c_lut, self.c_width,
self.da_lut, self.da_width,
self.db_lut, self.db_width,
out
)
c = _matmul.apply(a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, #
self.c_lut, self.c_width, #
self.da_lut, self.da_width, #
self.db_lut, self.db_width, #
out)
return c
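The reflowed kernels above are driven through the `matmul` wrapper at the end of this file. A rough usage sketch; the constructor argument order and the compressed output shape are assumptions inferred from the surrounding code, not a verified API reference:

import torch
import triton.ops.blocksparse as blocksparse

Z, H, M, N, K, block = 4, 2, 256, 256, 256, 32
# layout[h, i, j] == 1 marks a block of the sparse result to keep
layout = torch.randint(0, 2, (H, M // block, N // block), dtype=torch.int64)

# 'sdd' mode: sparse = dense x dense; only the blocks marked in `layout` are
# computed (assumed constructor order: layout, block, mode, device, ...)
dot = blocksparse.matmul(layout, block, "sdd", device="cuda", trans_a=False, trans_b=False)

a = torch.randn(Z, H, M, K, device="cuda", dtype=torch.float16)
b = torch.randn(Z, H, K, N, device="cuda", dtype=torch.float16)
c = dot(a, b)  # assumed compressed shape: (Z, layout.sum(), block, block)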

View File

@@ -18,14 +18,13 @@ def num_warps(n):
@jit
def _blocksparse_softmax_fwd(
Out, A, stride_xz, LUT,
R, extent, stride_zr, stride_hr, # relative attention
scale, is_causal,
ROW_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
IS_DENSE: tl.constexpr,
):
def _blocksparse_softmax_fwd(Out, A, stride_xz, LUT, #
R, extent, stride_zr, stride_hr, # relative attention
scale, is_causal, #
ROW_SIZE: tl.constexpr, #
BLOCK_SIZE: tl.constexpr, #
IS_DENSE: tl.constexpr #
):
h = tl.program_id(0)
m = tl.program_id(1)
z = tl.program_id(2)
@@ -73,18 +72,16 @@ def _blocksparse_softmax_fwd(
@jit
def _blocksparse_softmax_bwd(
DA, stride_zdx,
DOut, stride_zdout,
Out, stride_zout,
scale,
LUT,
DR, extent, stride_zr, stride_hr, stride_er,
is_causal,
ROW_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
IS_DENSE: tl.constexpr,
):
def _blocksparse_softmax_bwd(DA, stride_zdx, #
DOut, stride_zdout, #
Out, stride_zout, #
scale, #
LUT, #
DR, extent, stride_zr, stride_hr, stride_er, #
is_causal, #
ROW_SIZE: tl.constexpr, #
BLOCK_SIZE: tl.constexpr, #
IS_DENSE: tl.constexpr):
h = tl.program_id(0)
m = tl.program_id(1)
z = tl.program_id(2)
@@ -133,6 +130,7 @@ def _blocksparse_softmax_bwd(
class _softmax(torch.autograd.Function):
@staticmethod
def make_lut(layout, block, device):
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
@@ -151,10 +149,7 @@ class _softmax(torch.autograd.Function):
return lut, int(total_sizes.max())
@staticmethod
def forward(
ctx, a, scale, rel_logits, is_causal,
spdims, block, lut, maxlut, is_dense
):
def forward(ctx, a, scale, rel_logits, is_causal, spdims, block, lut, maxlut, is_dense):
if scale is not None and isinstance(scale, torch.Tensor):
assert scale.device.type == "cpu"
scale = scale.item()
@@ -165,14 +160,14 @@ class _softmax(torch.autograd.Function):
# enqueue kernel
out = torch.empty_like(a)
_blocksparse_softmax_fwd[grid](
out, a, a.stride(0), lut,
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
scale,
is_causal,
BLOCK_SIZE=block,
ROW_SIZE=next_power_of_2(maxlut),
IS_DENSE=is_dense,
num_warps=num_warps(maxlut)
out, a, a.stride(0), lut, #
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn#
scale, #
is_causal, #
BLOCK_SIZE=block, #
ROW_SIZE=next_power_of_2(maxlut), #
IS_DENSE=is_dense, #
num_warps=num_warps(maxlut) #
)
# save to context
# ctx.mark_dirty(x)
@@ -201,28 +196,23 @@ class _softmax(torch.autograd.Function):
grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)
da = torch.empty_like(dout)
_blocksparse_softmax_bwd[grid](
da, da.stride(0),
dout, dout.stride(0),
out, out.stride(0),
ctx.scale,
lut,
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],
ctx.is_causal,
BLOCK_SIZE=ctx.block,
ROW_SIZE=next_power_of_2(ctx.maxlut),
IS_DENSE=ctx.is_dense,
num_warps=num_warps(ctx.maxlut)
da, da.stride(0), #
dout, dout.stride(0), #
out, out.stride(0), #
ctx.scale, #
lut, #
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2], #
ctx.is_causal, #
BLOCK_SIZE=ctx.block, #
ROW_SIZE=next_power_of_2(ctx.maxlut), #
IS_DENSE=ctx.is_dense, #
num_warps=num_warps(ctx.maxlut) #
)
return (da, None, None, dr, None,
None, None, None, None, None,
None,
None, None, None,
None,
None, None, None
)
return (da, None, None, dr, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
class softmax:
def __init__(self, layout, block, device, is_dense=False):
self.spdims = layout.shape
self.layout = layout
@@ -233,8 +223,6 @@ class softmax:
def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False):
if rel_logits is not None and rel_logits.dtype != a.dtype:
raise ValueError(f"relative position embedding must be {a.dtype}")
a = _softmax.apply(
a, scale, rel_logits, is_causal,
self.spdims, self.block, self.lut, self.maxlut, self.is_dense,
)
a = _softmax.apply(a, scale, rel_logits, is_causal, self.spdims, self.block, self.lut, self.maxlut,
self.is_dense)
return a
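For orientation, a rough usage sketch of the two block-sparse ops touched above. This is a hedged illustration, not taken from this patch: it assumes `matmul` keeps its usual `(layout, block, mode, device, trans_a, trans_b, trans_c)` constructor and that `softmax` keeps the `(layout, block, device, is_dense)` signature shown in the hunk.

    import torch
    from triton.ops.blocksparse import matmul, softmax

    block, H, M, N, K = 16, 2, 256, 256, 256
    # 0/1 block mask, one (M//block) x (N//block) slice per head (illustrative layout)
    layout = torch.randint(0, 2, (H, M // block, N // block), dtype=torch.int64)

    dot = matmul(layout, block, "sdd", device="cuda")   # dense x dense -> sparse scores
    sm = softmax(layout, block, device="cuda")

    a = torch.randn(1, H, M, K, device="cuda", dtype=torch.float16)
    b = torch.randn(1, H, K, N, device="cuda", dtype=torch.float16)
    w = dot(a, b)                                  # packed non-zero blocks (sparse scores)
    p = sm(w, scale=1.0 / K**0.5, is_causal=True)  # exercises _blocksparse_softmax_fwd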

View File

@@ -59,6 +59,7 @@ def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
class _cross_entropy(torch.autograd.Function):
@classmethod
def forward(cls, ctx, logits, indices):
# make sure we can use triton

View File

@@ -15,20 +15,19 @@ from .. import language as tl
@jit
def _fwd_kernel(
Q, K, V, sm_scale,
L,
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX,
Z_H_N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
IS_CAUSAL: tl.constexpr,
):
def _fwd_kernel(Q, K, V, sm_scale, #
L, #
Out, #
stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vn, stride_vk, #
stride_oz, stride_oh, stride_om, stride_on, #
Z, H, N_CTX, #
Z_H_N_CTX, #
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
BLOCK_N: tl.constexpr, #
IS_CAUSAL: tl.constexpr #
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
qvk_offset = off_hz * stride_qh
@@ -40,7 +39,7 @@ def _fwd_kernel(
strides=(stride_kk, stride_kn),
offsets=(0, vk_offset),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
order=(0, 1),
)
V_block_ptr = tl.make_block_ptr(
base=V,
@@ -48,7 +47,7 @@ def _fwd_kernel(
strides=(stride_vn, stride_vk),
offsets=(vk_offset, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
order=(1, 0),
)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
@@ -104,7 +103,7 @@ def _fwd_kernel(
strides=(stride_om, stride_on),
offsets=(vk_offset + start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
order=(1, 0),
)
# O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
tl.store(O_block_ptr, acc.to(K.dtype.element_ty))
@@ -112,9 +111,11 @@ def _fwd_kernel(
@jit
def _bwd_preprocess(
Out, DO,
Out,
DO,
Delta,
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
BLOCK_M: tl.constexpr,
D_HEAD: tl.constexpr,
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
@@ -128,40 +129,48 @@ def _bwd_preprocess(
@jit
def _bwd_kernel_one_col_block(
Q, K, V, sm_scale, qk_scale,
Out, DO,
DQ, DK, DV,
L,
D,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
Z, H, N_CTX,
off_hz, start_n, num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
CAUSAL: tl.constexpr,
MMA_V3: tl.constexpr
):
if SEQUENCE_PARALLEL:
DQ += stride_dqa.to(tl.int64) * start_n
def _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, #
Out, DO, #
DQ, DK, DV, #
L, #
D, #
Q_block_ptr, K_block_ptr, V_block_ptr, #
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vn, stride_vk, #
Z, H, N_CTX, #
off_h, off_z, off_hz, start_n, num_block, #
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
BLOCK_N: tl.constexpr, #
SEQUENCE_PARALLEL: tl.constexpr, #
CAUSAL: tl.constexpr, #
MMA_V3: tl.constexpr #
):
if CAUSAL:
lo = start_n * BLOCK_M
else:
lo = 0
Q_offset = (off_z * stride_qz + off_h * stride_qh) // stride_qm
DQ_offset = off_z * stride_qz + off_h * stride_qh
K_offset = (off_z * stride_kz + off_h * stride_kh) // stride_kn
V_offset = (off_z * stride_vz + off_h * stride_vh) // stride_vn
if SEQUENCE_PARALLEL:
DQ_offset += stride_dqa.to(tl.int64) * start_n
DQ_offset = DQ_offset // stride_qm
Q_block_ptr = tl.advance(Q_block_ptr, (lo + Q_offset, 0))
K_block_ptr = tl.advance(K_block_ptr, (start_n * BLOCK_M + K_offset, 0))
V_block_ptr = tl.advance(V_block_ptr, (start_n * BLOCK_M + V_offset, 0))
DO_block_ptr = tl.advance(DO_block_ptr, (lo + Q_offset, 0))
DQ_block_ptr = tl.advance(DQ_block_ptr, (lo + DQ_offset, 0))
DK_block_ptr = tl.advance(DK_block_ptr, (start_n * BLOCK_M + K_offset, 0))
DV_block_ptr = tl.advance(DV_block_ptr, (start_n * BLOCK_M + V_offset, 0))
# initialize row/col offsets
offs_qm = lo + tl.arange(0, BLOCK_M)
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
offs_m = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_DMODEL)
# initialize pointers to value-like data
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk)
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
l_ptrs = L + off_hz * N_CTX
@@ -169,17 +178,17 @@ def _bwd_kernel_one_col_block(
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# k and v stay in SRAM throughout
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
k = tl.load(K_block_ptr)
v = tl.load(V_block_ptr)
# loop over rows
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip
q = tl.load(q_ptrs)
q = tl.load(Q_block_ptr)
# recompute p = softmax(qk, dim=-1).T
# NOTE: `do` is pre-divided by `l`; no normalization here
if CAUSAL:
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.0), float("-inf"))
else:
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
@@ -187,7 +196,7 @@ def _bwd_kernel_one_col_block(
l_i = tl.load(l_ptrs + offs_m_curr)
p = tl.math.exp2(qk - l_i[:, None])
# compute dv
do = tl.load(do_ptrs)
do = tl.load(DO_block_ptr)
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do, allow_tf32=True)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
@@ -199,97 +208,156 @@ def _bwd_kernel_one_col_block(
dk += tl.dot(tl.trans(ds), q, allow_tf32=True)
# compute dq
if not SEQUENCE_PARALLEL:
dq = tl.load(dq_ptrs)
dq = tl.load(DQ_block_ptr)
dq += tl.dot(ds, k, allow_tf32=True)
tl.store(dq_ptrs, dq)
tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty))
elif SEQUENCE_PARALLEL:
if MMA_V3:
dq = tl.dot(ds, k, allow_tf32=True)
else:
# does not work with mma v3, because M % 64 != 0
dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds), allow_tf32=True))
tl.store(dq_ptrs, dq)
tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty))
# increment pointers
dq_ptrs += BLOCK_M * stride_qm
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm
DQ_block_ptr = tl.advance(DQ_block_ptr, (BLOCK_M, 0))
Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0))
DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0))
# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk)
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
tl.store(DV_block_ptr, dv.to(V.dtype.element_ty))
tl.store(DK_block_ptr, dk.to(K.dtype.element_ty))
@jit
def _bwd_kernel(
# fmt: off
Q, K, V, sm_scale,
Out, DO,
DQ, DK, DV,
L,
D,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
CAUSAL: tl.constexpr,
MMA_V3: tl.constexpr
# fmt: on
):
def _bwd_kernel(Q, K, V, sm_scale, #
Out, DO, #
DQ, DK, DV, #
L, #
D, #
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vn, stride_vk, #
Z, H, N_CTX, #
Z_H_N_CTX, #
SQ_Z_H_N_CTX, #
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
BLOCK_N: tl.constexpr, #
SEQUENCE_PARALLEL: tl.constexpr, #
CAUSAL: tl.constexpr, #
MMA_V3: tl.constexpr #
):
qk_scale = sm_scale * 1.44269504
off_hz = tl.program_id(0)
off_z = off_hz // H
off_h = off_hz % H
# offset pointers for batch/head
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_kz + off_h * stride_kh
V += off_z * stride_vz + off_h * stride_vh
DO += off_z * stride_qz + off_h * stride_qh
DQ += off_z * stride_qz + off_h * stride_qh
DK += off_z * stride_kz + off_h * stride_kh
DV += off_z * stride_vz + off_h * stride_vh
Q_block_ptr = tl.make_block_ptr(
base=Q,
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
K_block_ptr = tl.make_block_ptr(
base=K,
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_kn, stride_kk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
V_block_ptr = tl.make_block_ptr(
base=V,
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_vn, stride_vk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
DO_block_ptr = tl.make_block_ptr(
base=DO,
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
if SEQUENCE_PARALLEL:
DQ_block_ptr = tl.make_block_ptr(
base=DQ,
shape=(SQ_Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
else:
DQ_block_ptr = tl.make_block_ptr(
base=DQ,
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
DK_block_ptr = tl.make_block_ptr(
base=DK,
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_kn, stride_kk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
DV_block_ptr = tl.make_block_ptr(
base=DV,
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_vn, stride_vk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
num_block_n = tl.cdiv(N_CTX, BLOCK_N)
if not SEQUENCE_PARALLEL:
for start_n in range(0, num_block_n):
_bwd_kernel_one_col_block(
Q, K, V, sm_scale, qk_scale, Out, DO,
DQ, DK, DV,
L,
D,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
Z, H, N_CTX,
off_hz, start_n, num_block_n,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
MMA_V3=MMA_V3
)
_bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, #
DQ, DK, DV, #
L, #
D, #
Q_block_ptr, K_block_ptr, V_block_ptr, #
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vn, stride_vk, #
Z, H, N_CTX, #
off_h, off_z, off_hz, start_n, num_block_n, #
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, #
BLOCK_N=BLOCK_N, #
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, #
CAUSAL=CAUSAL, #
MMA_V3=MMA_V3 #
)
else:
start_n = tl.program_id(1)
_bwd_kernel_one_col_block(
Q, K, V, sm_scale, qk_scale, Out, DO,
DQ, DK, DV,
L,
D,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
Z, H, N_CTX,
off_hz, start_n, num_block_n,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
MMA_V3=MMA_V3
)
_bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, #
DQ, DK, DV, #
L, #
D, #
Q_block_ptr, K_block_ptr, V_block_ptr, #
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vn, stride_vk, #
Z, H, N_CTX, #
off_h, off_z, off_hz, start_n, num_block_n, #
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, #
BLOCK_N=BLOCK_N, #
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, #
CAUSAL=CAUSAL, #
MMA_V3=MMA_V3 #
)
class _attention(torch.autograd.Function):
@@ -315,19 +383,20 @@ class _attention(torch.autograd.Function):
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
_fwd_kernel[grid](
q, k, v, sm_scale,
L,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
q.shape[0] * q.shape[1] * q.shape[2],
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
IS_CAUSAL=causal,
num_warps=num_warps,
num_stages=4)
q, k, v, sm_scale, #
L, #
o, #
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
q.shape[0], q.shape[1], q.shape[2], #
q.shape[0] * q.shape[1] * q.shape[2], #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, #
IS_CAUSAL=causal, #
num_warps=num_warps, #
num_stages=4 #
)
ctx.save_for_backward(q, k, v, o, L)
ctx.grid = grid
@@ -348,35 +417,39 @@ class _attention(torch.autograd.Function):
do = do.contiguous()
if sequence_parallel:
replicas = cdiv(seq_len_kv, BLOCK)
new_dq_shape = (replicas,) + q.shape
new_dq_shape = (replicas, ) + q.shape
dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype)
else:
dq = torch.zeros_like(q, dtype=torch.float32)
dq = torch.zeros_like(q, dtype=q.dtype)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
delta = torch.empty_like(L)
_bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1], )](
o, do,
o,
do,
delta,
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
BLOCK_M=BLOCK,
D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)](
q, k, v, ctx.sm_scale,
o, do,
dq, dk, dv,
L,
delta,
o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
SEQUENCE_PARALLEL=sequence_parallel,
CAUSAL=ctx.causal,
MMA_V3=MMA_V3,
num_warps=8,
num_stages=1,
q, k, v, ctx.sm_scale, #
o, do, #
dq, dk, dv, #
L, #
delta, #
o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
q.shape[0], q.shape[1], q.shape[2], #
q.shape[0] * q.shape[1] * q.shape[2], #
cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2], #
BLOCK_M=BLOCK, BLOCK_N=BLOCK, #
BLOCK_DMODEL=ctx.BLOCK_DMODEL, #
SEQUENCE_PARALLEL=sequence_parallel, #
CAUSAL=ctx.causal, #
MMA_V3=MMA_V3, #
num_warps=8, #
num_stages=1 #
)
if len(dq.shape) == 5:
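For reference, a rough end-to-end sketch of the op driven by the kernels above. Hedged: it assumes the module exports `attention = _attention.apply` with the argument order `(q, k, v, causal, sm_scale, sequence_parallel=False)`.

    import torch
    import triton.ops

    Z, H, N_CTX, D_HEAD = 2, 4, 1024, 64
    q, k, v = (torch.randn(Z, H, N_CTX, D_HEAD, device="cuda", dtype=torch.float16,
                           requires_grad=True) for _ in range(3))
    sm_scale = 1.0 / D_HEAD**0.5

    o = triton.ops.attention(q, k, v, True, sm_scale)  # causal forward (_fwd_kernel)
    o.sum().backward()                                 # _bwd_preprocess + _bwd_kernel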

View File

@@ -37,8 +37,9 @@ def get_configs_io_bound():
num_stages=num_stages, num_warps=num_warps))
# split_k
for split_k in [2, 4, 8, 16]:
configs.append(Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
configs.append(
Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs
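The `pre_hook=init_to_zero('C')` above matters because a `SPLIT_K > 1` launch accumulates partial K-slices into C with atomic adds, so C must start at zero before each benchmarked run. For context, the helper is presumed to be the usual one-liner:

    def init_to_zero(name):
        # zero the named kernel argument before the config is benchmarked/launched
        return lambda nargs: nargs[name].zero_()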
@@ -69,22 +70,22 @@ def get_configs_io_bound():
prune_configs_by={
'early_config_prune': early_config_prune,
'perf_model': estimate_matmul_time,
'top_k': 10
'top_k': 10,
},
)
@heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@jit
def _kernel(A, B, C, M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
dot_out_dtype: tl.constexpr,
allow_tf32: tl.constexpr,
fp8_fast_accum: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr
def _kernel(A, B, C, M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
dot_out_dtype: tl.constexpr, #
allow_tf32: tl.constexpr, #
fp8_fast_accum: tl.constexpr, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr #
):
# matrix multiplication
pid = tl.program_id(0)
@@ -184,14 +185,15 @@ class _matmul(torch.autograd.Function):
ab_dtype = False
# launch kernel
grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
dot_out_dtype=dot_out_dtype,
allow_tf32=allow_tf32,
fp8_fast_accum=fp8_fast_accum,
GROUP_M=8, AB_DTYPE=ab_dtype)
_kernel[grid](
a, b, c, M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
dot_out_dtype=dot_out_dtype, #
allow_tf32=allow_tf32, #
fp8_fast_accum=fp8_fast_accum, #
GROUP_M=8, AB_DTYPE=ab_dtype)
return c
@staticmethod
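A minimal, hedged usage sketch of the op this kernel backs (assuming `triton.ops.matmul` is still exported as `_matmul.apply`):

    import torch
    import triton.ops

    a = torch.randn(512, 384, device="cuda", dtype=torch.float16)
    b = torch.randn(384, 256, device="cuda", dtype=torch.float16)
    c = triton.ops.matmul(a, b)          # autotuned; may select a SPLIT_K config
    torch.testing.assert_close(c, a @ b, rtol=1e-2, atol=1e-2)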

View File

@@ -5,8 +5,7 @@ import torch
from .. import cdiv
from .._C.libtriton.triton import runtime
from ..runtime import driver
from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops,
nvsmi)
from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops, nvsmi)
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
@@ -14,7 +13,8 @@ def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
total_warps = num_ctas * min(num_warps, 4)
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, cur_sm_clock, backend, device)
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(
dtype, cur_sm_clock, backend, device)
return tflops
@@ -35,12 +35,12 @@ def get_tflops(backend, device, num_ctas, num_warps, dtype):
def estimate_matmul_time(
# backend, device,
num_warps, num_stages,
A, B, C,
M, N, K,
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
debug=False, **kwargs
# backend, device,
num_warps, num_stages, #
A, B, C, #
M, N, K, #
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, #
debug=False, **kwargs #
):
''' return estimated running time in ms
= max(compute, loading) + store '''
@@ -149,8 +149,9 @@ def early_config_prune(configs, named_args):
optimal_num_stages = ldgsts_latency / mma_cycles
# nearest stages, prefer large #stages
nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
nearest = heapq.nsmallest(
2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
for n in nearest:
pruned_configs.append(n[0])

View File

@@ -1,8 +1,6 @@
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune,
heuristics)
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune, heuristics)
from .driver import driver
from .jit import (JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret,
version_key)
from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret
__all__ = [
"driver",
@@ -12,7 +10,6 @@ __all__ = [
"heuristics",
"JITFunction",
"KernelInterface",
"version_key",
"reinterpret",
"TensorWrapper",
"OutOfResources",

View File

@@ -9,11 +9,10 @@ from .jit import KernelInterface
class OutOfResources(Exception):
def __init__(self, required, limit, name):
self.message = f'out of resource: {name}, '\
f'Required: {required}, '\
f'Hardware limit: {limit}'
self.message += '. Reducing block sizes or `num_stages` may help.'
self.message = (f"out of resource: {name}, Required: {required}, Hardware limit: {limit}. " +
"Reducing block sizes or `num_stages` may help.")
self.required = required
self.limit = limit
self.name = name
@@ -25,38 +24,77 @@ class OutOfResources(Exception):
class Autotuner(KernelInterface):
<<<<<<< HEAD
def __init__(self, fn, arg_names, configs, key, verbose, reset_to_zero, prune_configs_by: Dict = None, warmup=25, rep=100):
'''
=======
def __init__(
self,
fn,
arg_names,
configs,
key,
reset_to_zero,
restore_value,
prune_configs_by: Dict = None,
warmup=25,
rep=100,
):
"""
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predict running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
'''
"""
if not configs:
self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)]
else:
self.configs = configs
self.key_idx = [arg_names.index(k) for k in key]
self.cache = {}
# hook to reset all required tensor to zeros before relaunching a kernel
self.hook = lambda args: 0
self.arg_names = arg_names
# Reset to zero or restore values
self.reset_idx = []
if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
self.restore_idx = []
if restore_value is not None:
self.restore_idx = [arg_names.index(k) for k in restore_value]
def _hook(args):
# Hook to reset or restore for required tensors
self.pre_hook = lambda args, reset_only=False: 0
self.post_hook = lambda args: 0
if len(self.reset_idx) > 0 or len(self.restore_idx) > 0:
def _pre_hook(args, reset_only=False):
for i in self.reset_idx:
args[i].zero_()
self.hook = _hook
self.arg_names = arg_names
# prune configs
if not reset_only:
self.restore_copies = [args[i].clone() for i in self.restore_idx]
self.pre_hook = _pre_hook
if len(self.restore_idx) > 0:
def _post_hook(args):
for i, j in enumerate(self.restore_idx):
args[j].copy_(self.restore_copies[i])
self.restore_copies = []
self.post_hook = _post_hook
# Prune configs
if prune_configs_by:
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
if 'early_config_prune' in prune_configs_by:
early_config_prune = prune_configs_by['early_config_prune']
perf_model, top_k = prune_configs_by["perf_model"], prune_configs_by["top_k"]
if "early_config_prune" in prune_configs_by:
early_config_prune = prune_configs_by["early_config_prune"]
else:
perf_model, top_k, early_config_prune = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.early_config_prune = early_config_prune
self.fn = fn
self.warmup = warmup
self.rep = rep
@@ -67,10 +105,8 @@ class Autotuner(KernelInterface):
# as kwargs and by the autotuner
conflicts = meta.keys() & config.kwargs.keys()
if conflicts:
raise ValueError(
f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols."
)
raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols.")
# augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs)
full_nargs = {**self.nargs, **current}
@@ -78,16 +114,22 @@ class Autotuner(KernelInterface):
def kernel_call():
if config.pre_hook:
config.pre_hook(full_nargs)
self.hook(args)
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization,
# enable_persistent=False,
**current)
self.pre_hook(args)
self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization,
# enable_persistent=False,
**current,
)
self.post_hook(args)
try:
return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
except OutOfResources:
return [float('inf'), float('inf'), float('inf')]
return [float("inf"), float("inf"), float("inf")]
def get_best_config(self):
return self.best_config
@@ -110,12 +152,11 @@ class Autotuner(KernelInterface):
# prune configs
pruned_configs = self.prune_configs(kwargs)
bench_start = time.time()
timings = {config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs}
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
bench_end = time.time()
self.bench_time = bench_end - bench_start
self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args)
self.pre_hook(args, reset_only=True)
self.configs_timings = timings
if self.verbose:
print(str(key) + ": " + str(self.cache[key]))
@@ -126,9 +167,15 @@ class Autotuner(KernelInterface):
full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs}
if config.pre_hook is not None:
config.pre_hook(full_nargs)
ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization, **kwargs, **config.kwargs)
ret = self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization,
**kwargs,
**config.kwargs,
)
self.nargs = None
return ret
@@ -142,17 +189,20 @@ class Autotuner(KernelInterface):
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {
config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
num_warps=config.num_warps,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization,
enable_persistent=config.enable_persistent)
config:
self.perf_model(
**self.nargs,
**kwargs,
**config.kwargs,
num_stages=config.num_stages,
num_warps=config.num_warps,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization,
enable_persistent=config.enable_persistent,
)
for config in pruned_configs
}
pruned_configs = sorted(
est_timing.keys(),
key=lambda x: est_timing[x])[
:top_k]
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
return pruned_configs
def warmup(self, *args, **kwargs):
@@ -195,13 +245,14 @@ class Config:
self.num_ctas = num_ctas
self.num_stages = num_stages
self.enable_warp_specialization = enable_warp_specialization
# TODO[shuhaoj]: May make enable_persistent configurable in future if necessay.
# TODO[shuhaoj]: May make enable_persistent configurable in future if necessary.
self.enable_persistent = False
self.pre_hook = pre_hook
def __str__(self):
res = []
for k, v in self.kwargs.items():
<<<<<<< HEAD
res.append(f'{k}: {v}')
res.append(f'num_warps: {self.num_warps}')
## Comment out Hopper specific parameters
@@ -214,6 +265,18 @@ class Config:
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=False, warmup=25, rep=100):
=======
res.append(f"{k}: {v}")
res.append(f"num_warps: {self.num_warps}")
res.append(f"num_ctas: {self.num_ctas}")
res.append(f"num_stages: {self.num_stages}")
res.append(f"enable_warp_specialization: {self.enable_warp_specialization}")
res.append(f"enable_persistent: {self.enable_persistent}")
return ", ".join(res)
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, warmup=25, rep=100):
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
@@ -244,6 +307,8 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=Fa
'early_config_prune'(optional): a function used to do early pruning (e.g., of num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
:param restore_value: a list of argument names whose value will be restored after evaluating any configs.
:type restore_value: list[str]
:param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25.
:type warmup: int
:param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
@@ -251,8 +316,13 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=Fa
:param verbose: a boolean that controls whether the best_config for each key is printed
:type verbose: bool
"""
def decorator(fn):
<<<<<<< HEAD
return Autotuner(fn, fn.arg_names, configs, key, verbose, reset_to_zero, prune_configs_by, warmup, rep)
=======
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, prune_configs_by, warmup, rep)
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
return decorator
@@ -286,6 +356,7 @@ def heuristics(values):
each such function takes a list of positional arguments as input.
:type values: dict[str, Callable[[list[Any]], Any]]
"""
def decorator(fn):
return Heuristics(fn, fn.arg_names, values)
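A hedged sketch of how the new `restore_value` knob (upstream side of the conflict above) is meant to be used: the decorated kernel mutates its output in place, so the autotuner restores the tensor after each timed config. Names below are illustrative.

    import torch
    import triton
    import triton.language as tl

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK": 1024}, num_warps=4),
            triton.Config({"BLOCK": 2048}, num_warps=8),
        ],
        key=["n_elements"],
        restore_value=["out_ptr"],  # undo in-place updates between benchmarked configs
    )
    @triton.jit
    def add_inplace(out_ptr, x_ptr, n_elements, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n_elements
        x = tl.load(x_ptr + offs, mask=mask)
        out = tl.load(out_ptr + offs, mask=mask)
        tl.store(out_ptr + offs, out + x, mask=mask)

    x = torch.randn(4096, device="cuda")
    out = torch.zeros_like(x)
    grid = lambda META: (triton.cdiv(x.numel(), META["BLOCK"]), )
    add_inplace[grid](out, x, x.numel())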

View File

@@ -1,27 +1,42 @@
#include "cuda.h"
#include <dlfcn.h>
#include <stdbool.h>
#define PY_SSIZE_T_CLEAN
#include <Python.h>
static inline void gpuAssert(CUresult code, const char *file, int line) {
if (code != CUDA_SUCCESS) {
const char *prefix = "Triton Error [CUDA]: ";
const char *str;
cuGetErrorString(code, &str);
char err[1024] = {0};
strcat(err, prefix);
strcat(err, str);
PyGILState_STATE gil_state;
gil_state = PyGILState_Ensure();
PyErr_SetString(PyExc_RuntimeError, err);
PyGILState_Release(gil_state);
}
// Raises a Python exception and returns false if code is not CUDA_SUCCESS.
static bool gpuAssert(CUresult code, const char *file, int line) {
if (code == CUDA_SUCCESS)
return true;
const char *prefix = "Triton Error [CUDA]: ";
const char *str;
cuGetErrorString(code, &str);
char err[1024] = {0};
strcat(err, prefix);
strcat(err, str);
PyGILState_STATE gil_state;
gil_state = PyGILState_Ensure();
PyErr_SetString(PyExc_RuntimeError, err);
PyGILState_Release(gil_state);
return false;
}
#define CUDA_CHECK(ans) \
{ \
{ gpuAssert((ans), __FILE__, __LINE__); } \
}
// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block.
#define CUDA_CHECK_AND_RETURN_NULL(ans) \
do { \
if (!gpuAssert((ans), __FILE__, __LINE__)) \
return NULL; \
} while (0)
// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block.
#define CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \
do { \
if (!gpuAssert((ans), __FILE__, __LINE__)) { \
PyEval_RestoreThread(_save); \
return NULL; \
} \
} while (0)
#define ADD_ENUM_ITEM(value) \
do { \
@@ -200,16 +215,16 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
int sm_clock_rate;
int mem_clock_rate;
int mem_bus_width;
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate,
CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
&sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
@@ -237,33 +252,37 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
CUcontext pctx = 0;
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuCtxGetCurrent(&pctx));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&pctx));
if (!pctx) {
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
CUDA_CHECK(cuCtxSetCurrent(pctx));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuDevicePrimaryCtxRetain(&pctx, device));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(pctx));
}
CUDA_CHECK(cuModuleLoadData(&mod, data));
CUDA_CHECK(cuModuleGetFunction(&fun, mod, name));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuModuleLoadData(&mod, data));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuModuleGetFunction(&fun, mod, name));
// get allocated registers and spilled registers from the function
CUDA_CHECK(cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
CUDA_CHECK(
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
n_spills /= 4;
// set dynamic shared memory if necessary
int shared_optin;
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
if (shared > 49152 && shared_optin > 49152) {
CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
int shared_total, shared_static;
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
device));
CUDA_CHECK(cuFuncGetAttribute(&shared_static,
CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
CUDA_CHECK(
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute(
&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_optin - shared_static));
}
@@ -286,7 +305,7 @@ static PyObject *memAlloc(PyObject *self, PyObject *args) {
}
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuMemAlloc(&dptr, bytesize));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuMemAlloc(&dptr, bytesize));
Py_END_ALLOW_THREADS;
return PyLong_FromUnsignedLongLong((unsigned long long)dptr);
@@ -307,7 +326,8 @@ static PyObject *memcpyHtoD(PyObject *self, PyObject *args) {
srcHost = (const void *)srcHostPtr;
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuMemcpyHtoD(dstDevice, srcHost, byteCount));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuMemcpyHtoD(dstDevice, srcHost, byteCount));
Py_END_ALLOW_THREADS;
Py_RETURN_NONE;
@@ -321,7 +341,7 @@ static PyObject *memFree(PyObject *self, PyObject *args) {
}
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuMemFree(dptr));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuMemFree(dptr));
Py_END_ALLOW_THREADS;
Py_RETURN_NONE;
@@ -411,7 +431,7 @@ static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) {
}
// Call the function
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuTensorMapEncodeTiledHandle(
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuTensorMapEncodeTiledHandle(
tensorMap, tensorDataType, tensorRank, globalAddress, globalDim,
globalStrides, boxDim, elementStrides, interleave, swizzle, l2Promotion,
oobFill));

View File

@@ -19,6 +19,7 @@ def default_dump_dir():
class CacheManager(ABC):
def __init__(self, key):
pass
@@ -44,20 +45,21 @@ class CacheManager(ABC):
class FileCacheManager(CacheManager):
def __init__(self, key, override=False, dump=False):
self.key = key
self.lock_path = None
if (dump):
if dump:
self.cache_dir = default_dump_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
elif (override):
elif override:
self.cache_dir = default_override_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
else:
# create cache directory if it doesn't exist
self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir()
self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
if self.cache_dir:
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
@@ -93,9 +95,8 @@ class FileCacheManager(CacheManager):
result = {}
for c in child_paths:
p = self._make_path(c)
if not os.path.exists(p):
raise Exception(f"Group file {p} does not exist from group {grp_filename} ")
result[c] = p
if os.path.exists(p):
result[c] = p
return result
# Note a group of pushed files as being part of a group
@@ -142,6 +143,7 @@ def get_cache_manager(key) -> CacheManager:
if user_cache_manager is not None and user_cache_manager != __cache_cls_nme:
import importlib
module_path, clz_nme = user_cache_manager.split(":")
module = importlib.import_module(module_path)
__cache_cls = getattr(module, clz_nme)
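A hedged sketch of plugging a custom cache manager into this hook. It assumes the selector environment variable is `TRITON_CACHE_MANAGER`, with a value of the form `module:ClassName` matching the `split(":")` above; the module and class names here are made up.

    # my_cache.py (hypothetical module importable on PYTHONPATH)
    import os
    from triton.runtime.cache import FileCacheManager

    class ProjectCacheManager(FileCacheManager):
        """Redirects compiled-kernel artifacts into a project-local directory."""

        def __init__(self, key, override=False, dump=False):
            # Set before the parent reads TRITON_CACHE_DIR in its constructor.
            os.environ.setdefault("TRITON_CACHE_DIR", os.path.abspath(".triton-cache"))
            super().__init__(key, override=override, dump=dump)

    # Select it before any kernel is compiled:
    #   export TRITON_CACHE_MANAGER=my_cache:ProjectCacheManager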

View File

@@ -9,7 +9,6 @@ from .cache import get_cache_manager
class DriverBase(metaclass=abc.ABCMeta):
CUDA = 0
HIP = 1
@@ -19,6 +18,8 @@ class DriverBase(metaclass=abc.ABCMeta):
def __init__(self) -> None:
pass
# -----------------------------
# CUDA
# -----------------------------
@@ -27,7 +28,7 @@ class DriverBase(metaclass=abc.ABCMeta):
class CudaUtils(object):
def __new__(cls):
if not hasattr(cls, 'instance'):
if not hasattr(cls, "instance"):
cls.instance = super(CudaUtils, cls).__new__(cls)
return cls.instance
@@ -47,6 +48,7 @@ class CudaUtils(object):
with open(so, "rb") as f:
cache_path = cache.put(f.read(), fname, binary=True)
import importlib.util
spec = importlib.util.spec_from_file_location("cuda_utils", cache_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
@@ -66,7 +68,7 @@ class CudaUtils(object):
class CudaDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, 'instance'):
if not hasattr(cls, "instance"):
cls.instance = super(CudaDriver, cls).__new__(cls)
return cls.instance
@@ -74,14 +76,16 @@ class CudaDriver(DriverBase):
self.utils = CudaUtils()
self.backend = self.CUDA
# -----------------------------
# HIP
# -----------------------------
class HIPUtils(object):
def __new__(cls):
if not hasattr(cls, 'instance'):
if not hasattr(cls, "instance"):
cls.instance = super(HIPUtils, cls).__new__(cls)
return cls.instance
@@ -101,6 +105,7 @@ class HIPUtils(object):
with open(so, "rb") as f:
cache_path = cache.put(f.read(), fname, binary=True)
import importlib.util
spec = importlib.util.spec_from_file_location("hip_utils", cache_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
@@ -111,7 +116,7 @@ class HIPUtils(object):
class HIPDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, 'instance'):
if not hasattr(cls, "instance"):
cls.instance = super(HIPDriver, cls).__new__(cls)
return cls.instance
@@ -123,7 +128,7 @@ class HIPDriver(DriverBase):
class UnsupportedDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, 'instance'):
if not hasattr(cls, "instance"):
cls.instance = super(UnsupportedDriver, cls).__new__(cls)
return cls.instance
@@ -131,12 +136,14 @@ class UnsupportedDriver(DriverBase):
self.utils = None
self.backend = None
# -----------------------------
# Driver
# -----------------------------
class LazyProxy:
def __init__(self, init_fn):
self._init_fn = init_fn
self._obj = None
@@ -150,7 +157,7 @@ class LazyProxy:
return getattr(self._obj, name)
def __setattr__(self, name, value):
if name in ['_init_fn', '_obj']:
if name in ["_init_fn", "_obj"]:
super().__setattr__(name, value)
else:
self._initialize_obj()
@@ -172,6 +179,7 @@ class LazyProxy:
def initialize_driver():
import torch
if torch.version.hip is not None:
return HIPDriver()
elif torch.cuda.is_available():

View File

@@ -1,10 +1,8 @@
class OutOfResources(Exception):
def __init__(self, required, limit, name):
self.message = f'out of resource: {name}, '\
f'Required: {required}, '\
f'Hardware limit: {limit}'
self.message += '. Reducing block sizes or `num_stages` may help.'
self.message = f"out of resource: {name}, " f"Required: {required}, " f"Hardware limit: {limit}"
self.message += ". Reducing block sizes or `num_stages` may help."
self.required = required
self.limit = limit
self.name = name

View File

@@ -74,11 +74,15 @@ class BlockPointerHandle:
def wrap_ret(compute_ret_ty):
def wrapper(fn):
def wrapped(*args, **kwargs):
ret = fn(*args, **kwargs)
return TensorHandle(ret.data, compute_ret_ty(*args, **kwargs))
return wrapped
return wrapper
@@ -249,11 +253,13 @@ class Builder:
# ternary functions
def ternary_op(self, lhs, rhs, other, op):
return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype)
create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where)
# unary functions
def unary_op(self, arg, op):
return TensorHandle(op(arg.data), arg.dtype)
create_exp = lambda self, arg: self.unary_op(arg, np.exp)
create_cos = lambda self, arg: self.unary_op(arg, np.cos)
create_sin = lambda self, arg: self.unary_op(arg, np.sin)
@@ -279,7 +285,8 @@ class Builder:
dtype_tt = ptr.dtype.element_ty
return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype)
def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile):
def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy,
is_volatile):
ptrs, masks = ptr.materialize_pointers(boundary_check)
assert padding_option is None
other = None
@@ -297,6 +304,7 @@ class Builder:
def create_int_to_ptr(self, val, dst_ty):
return TensorHandle(val.data.astype(np.uint64), dst_ty)
# def create_cat(self, lhs, rhs):
# pass
@@ -360,7 +368,10 @@ class Builder:
def patch_attr(obj, name, member, builder):
new_member = lambda *args, member=member, **kwargs: (member(*args, **{k: v for k, v in kwargs.items() if k != '_builder'}, _builder=builder))
new_member = lambda *args, member=member, **kwargs: (member(*args, **
{k: v
for k, v in kwargs.items()
if k != "_builder"}, _builder=builder))
setattr(obj, name, new_member)
@@ -384,8 +395,8 @@ def _patch_lang_core(lang, builder):
def _new_reduce(input, axis, combine_fn):
fn = combine_fn.fn.__name__
mapping = {
'maximum': np.max,
'_sum_combine': np.sum,
"maximum": np.max,
"_sum_combine": np.sum,
}
ret = mapping[fn](input.handle.data, axis=axis)
ret_type = tl.block_type(input.dtype, ret.shape)
@@ -397,15 +408,16 @@ def _patch_lang_core(lang, builder):
def _patch_lang_math(lang, builder):
math = lang.math
mapping = {
'abs': 'abs',
'acos': 'arccos',
'asin': 'arcsin',
'exp2': 'exp2',
'log2': 'log2',
'max': 'maximum',
"abs": "abs",
"acos": "arccos",
"asin": "arcsin",
"exp2": "exp2",
"log2": "log2",
"max": "maximum",
}
def make_numpy(name):
def impl(*args, **kwargs):
ret_type = args[0].type # TODO: incorrect
ret_dtype = args[0].dtype # TODO: incorrect
@@ -414,15 +426,18 @@ def _patch_lang_math(lang, builder):
ret = getattr(np, mapping[name])(*args, **kwargs)
ret = tl.core.tensor(TensorHandle(ret, ret_dtype), ret_type)
return ret
return impl
def make_fallback(name):
def fallback(*args, **kwargs):
raise NotImplementedError(f"""
{name} not supported in interpreter mode: no known numpy implementation.
If you think that {name} in fact does have a numpy implementation, please add it
to the mapping in python/triton/interpreter/new_interpreter.py:_patch_lang_math.
""")
return fallback
for name, member in inspect.getmembers(math):
@@ -438,7 +453,7 @@ def _implicit_cvt(arg):
ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
handle = TensorHandle(np.array([arg], dtype=np.int32), ty)
return tl.tensor(handle, ty)
if hasattr(arg, 'data_ptr'):
if hasattr(arg, "data_ptr"):
ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
return tl.tensor(handle, ty)
@@ -453,28 +468,29 @@ def _unwrap(tensor):
builder = Builder()
RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization', 'enable_fp_fusion']
RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_warp_specialization", "enable_fp_fusion"]
class GridExecutor:
def __init__(self, fn, arg_names, grid):
from .jit import _normalize_ty # TODO: modularize
self.fn = fn
self.arg_names = arg_names
self.grid = grid
__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
self.constexprs = [name for name in arg_names if __annotations__.get(name) == 'constexpr']
self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"]
def _patch_lang(self, builder):
lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
assert len(lang) == 1, "triton.language must be visible from within jit'd function"
_patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
_patch_lang_tensor(getattr(lang[0], "tensor"), builder)
_patch_lang_core(lang[0], builder)
_patch_lang_math(lang[0], builder)
def __call__(self, *args_dev, **kwargs):
args_hst = [_unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev]
args_hst = [_unwrap(arg).cpu() if hasattr(arg, "data_ptr") else arg for arg in args_dev]
# removes reserved keywords from kwargs
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}
# remaps core language functions to interpreted ones
@@ -486,7 +502,7 @@ class GridExecutor:
# iterate through grid
grid = self.grid(args) if callable(self.grid) else self.grid
assert len(grid) <= 3
grid = grid + (1,) * (3 - len(grid))
grid = grid + (1, ) * (3 - len(grid))
builder.set_grid_dim(*grid)
for x in range(grid[0]):
for y in range(grid[1]):
@@ -495,7 +511,7 @@ class GridExecutor:
self.fn(**args)
# copy arguments back to propagate side-effects
for arg_dev, arg_hst in zip(args_dev, args_hst):
if hasattr(arg_dev, 'data_ptr'):
if hasattr(arg_dev, "data_ptr"):
_unwrap(arg_dev).copy_(arg_hst.to(arg_dev.device))
@@ -504,17 +520,18 @@ class InterpretedFunction:
def _patch_lang(self, builder):
lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
assert len(lang) == 1, "triton.language must be visible from within jit'd function"
_patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
_patch_lang_tensor(getattr(lang[0], "tensor"), builder)
_patch_lang_core(lang[0], builder)
def __init__(self, fn) -> None:
self.fn = fn
def run(*args, **kwargs):
grid = kwargs['grid']
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ['grid']}
grid = kwargs["grid"]
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ["grid"]}
return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs)
self.run = run
signature = inspect.signature(fn)
self.arg_names = [v.name for v in signature.parameters.values()]

View File

@@ -5,48 +5,48 @@ import functools
import hashlib
import inspect
import os
import subprocess
import textwrap
from collections import defaultdict, namedtuple
from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast,
overload)
from functools import cached_property
from typing import Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast, overload
from .._C.libtriton.triton import TMAInfos
from ..common.backend import get_backend, path_to_ptxas
from ..language.core import dtype
from ..common.backend import get_backend, get_cuda_version_key
from .interpreter import InterpretedFunction
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TRITON_VERSION = "2.1.0"
def get_cuda_stream(idx=None):
if idx is None:
idx = get_current_device()
try:
from torch._C import _cuda_getCurrentRawStream
return _cuda_getCurrentRawStream(idx)
except ImportError:
import torch
return torch.cuda.current_stream(idx).cuda_stream
def get_current_device():
import torch
return torch.cuda.current_device()
def set_current_device(idx):
import torch
torch.cuda.set_device(idx)
def get_device_capability(idx):
import torch
return torch.cuda.get_device_capability(idx)
T = TypeVar('T')
T = TypeVar("T")
# -----------------------------------------------------------------------------
# Dependencies Finder
@@ -72,7 +72,8 @@ class DependenciesFinder(ast.NodeVisitor):
lhs = self.visit(node.value)
while isinstance(lhs, ast.Attribute):
lhs = self.visit(lhs.value)
if lhs is None or (getattr(lhs, "__name__", "") == "triton" or getattr(lhs, "__name__", "").endswith(".triton")):
if lhs is None or (getattr(lhs, "__name__", "") == "triton"
or getattr(lhs, "__name__", "").endswith(".triton")):
return None
return getattr(lhs, node.attr)
@@ -82,55 +83,26 @@ class DependenciesFinder(ast.NodeVisitor):
return
if inspect.isbuiltin(func):
return
if func.__module__ and (func.__module__.startswith('triton.') or '.triton.' in func.__module__):
if func.__module__ and (func.__module__.startswith("triton.") or ".triton." in func.__module__):
return
assert isinstance(func, JITFunction), f"Function \"{func.__name__}\" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this"
assert isinstance(
func, JITFunction
), f'Function "{func.__name__}" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this'
if func.hash is None:
tree = ast.parse(func.src)
finder = DependenciesFinder(func.__globals__, func.src)
finder.visit(tree)
func.hash = finder.ret
noinline = str(getattr(func, 'noinline', False))
noinline = str(getattr(func, "noinline", False))
self.ret = (self.ret + func.hash + noinline).encode("utf-8")
self.ret = hashlib.sha1(self.ret).hexdigest()
# -----------------------------------------------------------------------------
# JITFunction
# -----------------------------------------------------------------------------
@functools.lru_cache()
def version_key():
import pkgutil
contents = []
# frontend
with open(__file__, "rb") as f:
contents += [hashlib.sha1(f.read()).hexdigest()]
# compiler
compiler_path = os.path.join(TRITON_PATH, 'compiler')
for lib in pkgutil.iter_modules([compiler_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.sha1(f.read()).hexdigest()]
# backend
libtriton_hash = hashlib.sha1()
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
while True:
chunk = f.read(1024 ** 2)
if not chunk:
break
libtriton_hash.update(chunk)
contents.append(libtriton_hash.hexdigest())
# language
language_path = os.path.join(TRITON_PATH, 'language')
for lib in pkgutil.iter_modules([language_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.sha1(f.read()).hexdigest()]
# ptxas version
ptxas = path_to_ptxas()[0]
ptxas_version = hashlib.sha1(subprocess.check_output([ptxas, "--version"])).hexdigest()
return '-'.join(TRITON_VERSION) + '-' + ptxas_version + '-' + '-'.join(contents)
def _normalize_ty(ty) -> str:
if isinstance(ty, type):
return ty.__name__
@@ -139,6 +111,85 @@ def _normalize_ty(ty) -> str:
return repr(ty)
class KernelParam:
"""Represents a parameter to a @jit'ed function.
A parameter is just the name plus metadata; a parameter plus a value is a
KernelArg.
"""
def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool):
self.num = num
self._param = param
self.do_not_specialize = do_not_specialize
@cached_property
def name(self):
return self._param.name
@cached_property
def annotation(self):
if not self._param.annotation or self._param.annotation == inspect.Parameter.empty:
return ""
return _normalize_ty(self._param.annotation)
@cached_property
def is_constexpr(self):
return "constexpr" in self.annotation
@property
def default(self):
return self._param.default
@property
def has_default(self):
return self._param.default != inspect.Parameter.empty
class KernelArg:
"""Represents an argument to a @jit'ed function.
An argument is a parameter plus a value.
"""
def __init__(self, value, param):
self.value = value
self.param = param
@property
def name(self):
return self.param.name
def signature_key(self):
annotation = self.param.annotation
if "Tensor" in annotation:
return self.value.dtype
elif annotation == "bool":
return "i1"
elif annotation == "float":
return "fp32"
else:
return JITFunction._key_of(self.value)
def specialization_key(self):
assert not self.param.do_not_specialize
try:
return (self.value.data_ptr() % JITFunction.divisibility == 0, )
except AttributeError:
pass
if isinstance(self.value, int):
# bool is a subclass of int, so we don't check explicitly above.
return (
self.value % JITFunction.divisibility == 0,
self.value % JITFunction.divisibility_8 == 0,
self.value == 1,
)
return (False, )
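To make the specialization rule above concrete, here is a standalone mirror of `specialization_key` (illustration only; it does not import Triton, and it assumes `JITFunction.divisibility_8 == 8` alongside the `divisibility = 16` shown just below):

    DIVISIBILITY = 16    # matches JITFunction.divisibility
    DIVISIBILITY_8 = 8   # assumed value of JITFunction.divisibility_8

    def specialization_key(value):
        if hasattr(value, "data_ptr"):  # tensor-like: only pointer alignment matters
            return (value.data_ptr() % DIVISIBILITY == 0, )
        if isinstance(value, int):      # bool is an int subclass, handled here too
            return (value % DIVISIBILITY == 0, value % DIVISIBILITY_8 == 0, value == 1)
        return (False, )

    assert specialization_key(64) == (True, True, False)
    assert specialization_key(8) == (False, True, False)
    assert specialization_key(1) == (False, False, True)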
class KernelInterface(Generic[T]):
run: T
@@ -152,7 +203,6 @@ class KernelInterface(Generic[T]):
class JITFunction(KernelInterface[T]):
# Hook for inspecting compiled functions and modules
cache_hook = None
divisibility = 16
@@ -169,44 +219,44 @@ class JITFunction(KernelInterface[T]):
elif isinstance(arg, bool):
return "i1"
elif isinstance(arg, int):
if -2**31 <= arg and arg <= 2**31 - 1:
if -(2**31) <= arg and arg <= 2**31 - 1:
return "i32"
elif 2**63 <= arg and arg <= 2**64 - 1:
return "u64"
else:
return "i64"
elif isinstance(arg, float):
return 'fp32'
return "fp32"
elif arg is None:
return None
else:
raise TypeError(f'Unsupported type {type(arg)} for {arg}')
raise TypeError(f"Unsupported type {type(arg)} for {arg}")
@staticmethod
def _device_of(arg):
if hasattr(arg, "device"):
if hasattr(arg.device, 'type'):
return arg.device.type
return ''
try:
return arg.device.type
except AttributeError:
return ""
@staticmethod
def _pinned_memory_of(arg):
if hasattr(arg, "is_pinned"):
if isinstance(arg.is_pinned, Callable):
return arg.is_pinned()
return False
try:
return arg.is_pinned()
except (AttributeError, TypeError):
return False
@staticmethod
def _spec_of(arg):
if hasattr(arg, "data_ptr"):
return (arg.data_ptr() % JITFunction.divisibility == 0)
return arg.data_ptr() % JITFunction.divisibility == 0
elif isinstance(arg, int):
return (arg % 16 == 0, arg == 1)
return (arg is None, )
# TODO(jlebar): Fold this into the KernelArg class.
def _get_config(self, *args):
def is_divisible_by_16(x):
if hasattr(x, "data_ptr"):
return x.data_ptr() % JITFunction.divisibility == 0
@@ -222,28 +272,38 @@ class JITFunction(KernelInterface[T]):
if x is None:
return True
return False
divisible_by_16 = {i for i, arg in enumerate(
args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
divisible_by_8 = {i for i, arg in enumerate(
args) if is_divisible_by_8(arg) and i not in self.do_not_specialize}
divisible_by_16 = {
param.num
for param, arg in zip(self.params, args)
if is_divisible_by_16(arg) and not param.do_not_specialize
}
divisible_by_8 = {
param.num
for param, arg in zip(self.params, args)
if is_divisible_by_8(arg) and not param.do_not_specialize
}
equal_to_1 = {
i for i, arg in enumerate(args) if isinstance(
arg, int) and not isinstance(
arg, bool) and arg == 1 and i not in self.do_not_specialize}
param.num
for param, arg in zip(self.params, args)
if isinstance(arg, int) and not isinstance(arg, bool) and arg == 1 and not param.do_not_specialize
}
# folded equal_to_1 and None
# TODO: method to collect all folded args
none_args = {i for i, arg in enumerate(args) if arg is None and i not in self.do_not_specialize}
none_args = {param.num for param, arg in zip(self.params, args) if arg is None and not param.do_not_specialize}
ids_of_folded_args = equal_to_1 | none_args
return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])(
tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), tuple(divisible_by_8))
return namedtuple("instance_descriptor",
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])( #
tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args),
tuple(divisible_by_8))
# return _triton.code_gen.instance_descriptor(divisible_by_16,
# equal_to_1)
@staticmethod
def _type_of(key):
# None are nullptr -- implicitly converted to *i8
# `None` is nullptr. Implicitly convert to *i8.
if key is None:
return '*i8'
return "*i8"
dtype_str = str(key).split(".")[-1]
tys = {
"bool": "i1",
@@ -281,21 +341,46 @@ class JITFunction(KernelInterface[T]):
constants = dict(zip(self.constexprs, constexpr_key))
return constants
<<<<<<< HEAD
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization,enable_fp_fusion, extern_libs, configs):
=======
def _call_hook(
self,
key,
signature,
device,
constants,
num_warps,
num_ctas,
num_stages,
enable_warp_specialization,
enable_fp_fusion,
extern_libs,
configs,
):
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
if JITFunction.cache_hook is None:
return False
name = self.fn.__name__
module = self.fn.__module__
<<<<<<< HEAD
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs}), enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
=======
arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])])
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}, enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
key = str(key)
class LegacyCompiler:
def __init__(self, module, name):
self.module = module
self.name = name
pass
<<<<<<< HEAD
kwargs = dict(signature=signature, device=device, constants=constants,
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs,
configs=configs)
@@ -326,18 +411,43 @@ class JITFunction(KernelInterface[T]):
return 'fp32'
else:
return self._key_of(arg)
=======
kwargs = dict(
signature=signature,
device=device,
constants=constants,
num_warps=num_warps,
num_ctas=num_ctas,
num_stages=num_stages,
enable_warp_specialization=enable_warp_specialization,
enable_fp_fusion=enable_fp_fusion,
extern_libs=extern_libs,
configs=configs,
)
return JITFunction.cache_hook(
key=key,
repr=repr,
fn=LegacyCompiler(module, name),
compile={"key": key, **kwargs},
is_manual_warmup=False,
already_compiled=False,
)
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: List[bool]) -> str:
device_types = [device_type for device_type in device_types if device_type != '']
device_types = [device_type for device_type in device_types if device_type != ""]
# Return cuda if one of the input tensors is cuda
if 'cuda' in device_types:
if "cuda" in device_types:
import torch
return 'hip' if torch.version.hip else 'cuda'
is_cpu = all(device_type == 'cpu' for device_type in device_types)
return "hip" if torch.version.hip else "cuda"
is_cpu = all(device_type == "cpu" for device_type in device_types)
is_pinned_memory = any(pinned_memory_flag for pinned_memory_flag in pinned_memory_flags)
# Return cuda if all the input tensors are cpu while the memory is pinned
if is_cpu and is_pinned_memory:
<<<<<<< HEAD
return 'cuda'
return device_types[0] if len(device_types) > 0 else 'cuda'
@@ -452,16 +562,193 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
scope = {"launcher_body": launcher_body}
exec(src, scope)
return scope[self.fn.__name__]
=======
return "cuda"
return device_types[0] if len(device_types) > 0 else "cuda"
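A toy restatement of the rule above, for illustration only (pinned CPU tensors still dispatch to the GPU backend); this is not the method itself and omits the torch import used for the HIP check.

# Standalone restatement of the device-type resolution rule, for illustration.
def conclude_device_type(device_types, pinned_flags):
    device_types = [t for t in device_types if t]
    if "cuda" in device_types:
        return "cuda"  # ("hip" on ROCm builds, as above)
    if all(t == "cpu" for t in device_types) and any(pinned_flags):
        return "cuda"
    return device_types[0] if device_types else "cuda"

assert conclude_device_type(["cpu", ""], [True, False]) == "cuda"
assert conclude_device_type(["xpu"], [False]) == "xpu"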
def run(self, *args, **kwargs):
from ..compiler import CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps
# Get a compiler-flags arg like `num_warps` and remove it from kwargs.
def get_special_arg(name: str, default=None):
if name not in kwargs:
return default
ret = kwargs[name]
del kwargs[name]
return ret
grid = get_special_arg("grid")
num_warps = get_special_arg("num_warps")
num_ctas = get_special_arg("num_ctas", 1)
num_stages = get_special_arg("num_stages")
enable_warp_specialization = get_special_arg("enable_warp_specialization", False)
enable_fp_fusion = get_special_arg("enable_fp_fusion", True)
extern_libs = get_special_arg("extern_libs")
stream = get_special_arg("stream")
warmup = get_special_arg("warmup", False)
device = get_special_arg("device")
device_type = get_special_arg("device_type")
# Bind the remaining arguments to `fn`.
bound_args = self.signature.bind(*args, **kwargs)
bound_args.apply_defaults()
assert len(bound_args.arguments) == len(self.params)
args = [KernelArg(arg_value, param) for (_, arg_value), param in zip(bound_args.arguments.items(), self.params)]
non_constexpr_arg_values = [arg.value for arg in args if not arg.param.is_constexpr]
sig_key = tuple(arg.signature_key() for arg in args if not arg.param.is_constexpr)
spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize)
constexpr_key = tuple(arg.value for arg in args if arg.param.is_constexpr)
assert num_ctas > 0
assert grid is not None
if callable(grid):
# Arguments are passed as a dict to `grid`, by contract.
# TODO(jlebar): In the new launch API, pass the compiler flags as a
# second parameter to `grid`.
grid = grid(dict(bound_args.arguments))
grid_size = len(grid)
grid_0 = grid[0]
grid_1 = grid[1] if grid_size > 1 else 1
grid_2 = grid[2] if grid_size > 2 else 1
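The grid handling above accepts either a tuple or a callable; a minimal sketch of both call styles, where `some_kernel`, `x`, `n_elements` and `BLOCK_SIZE` are placeholder names and the launches themselves are left commented out.

# Sketch of the two accepted grid forms.
import triton

# 1) a static 1- to 3-element tuple:
# some_kernel[(1024, )](x, n_elements, BLOCK_SIZE=256)

# 2) a callable receiving the dict of bound kernel arguments (constexprs included):
grid = lambda meta: (triton.cdiv(meta["n_elements"], meta["BLOCK_SIZE"]), )
# some_kernel[grid](x, n_elements=n_elements, BLOCK_SIZE=256)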
if device_type is None:
device_types = [self._device_of(arg) for arg in non_constexpr_arg_values]
device_types = [_device_type for _device_type in device_types if _device_type != ""]
device_type = self._conclude_device_type(device_types,
[self._pinned_memory_of(arg) for arg in non_constexpr_arg_values])
device_backend = None
if device_type not in ["cuda"]:
device_backend = get_backend(device_type)
if device_backend is None:
raise ValueError("Cannot find backend for " + device_type)
if device is None:
if device_type in ["cuda"]:
device = get_current_device()
set_current_device(device)
else:
device = device_backend.get_current_device()
device_backend.set_current_device(device)
if stream is None and not warmup:
if device_type in ["cuda"]:
stream = get_cuda_stream(device)
else:
stream = device_backend.get_stream()
if num_warps is None:
num_warps = get_arch_default_num_warps(device_type)
if num_stages is None:
num_stages = get_arch_default_num_stages(device_type)
if device_type in ["cuda"]:
version_key = get_cuda_version_key()
else:
version_key = device_backend.get_version_key()
key = (
version_key,
sig_key,
constexpr_key,
spec_key,
num_warps,
num_ctas,
num_stages,
enable_warp_specialization,
enable_fp_fusion,
self.debug,
)
if extern_libs is not None:
key = (key, tuple(extern_libs.items()))
# Kernel is not cached; we have to compile.
if key not in self.cache[device]:
configs = (self._get_config(*[arg.value for arg in args]), )
constants = {
arg.param.num: arg.value
for arg in args
if arg.param.is_constexpr or arg.param.num in configs[0].equal_to_1 or arg.value is None
}
for i, arg in constants.items():
if callable(arg):
raise TypeError(f"Callable constexpr at index {i} is not supported")
# Build kernel signature -- doesn't include constexpr arguments.
signature = {
arg.param.num: self._type_of(self._key_of(arg.value))
for arg in args
if not arg.param.is_constexpr
}
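To make the bookkeeping concrete, a hypothetical example of what `signature` and `constants` would hold for the vector-add kernel shown later in this diff (three float32 tensors, an in-range `n_elements`, `BLOCK_SIZE=1024`); indices follow parameter order.

# Hypothetical contents for add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr).
signature = {0: "*fp32", 1: "*fp32", 2: "*fp32", 3: "i32"}  # constexprs are excluded
constants = {4: 1024}  # BLOCK_SIZE; equal-to-1 and None arguments would also be folded here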
if self._call_hook(
key,
signature,
device,
constants,
num_warps,
num_ctas,
num_stages,
enable_warp_specialization,
enable_fp_fusion,
extern_libs,
configs,
):
return None
self.cache[device][key] = compile(
self,
signature=signature,
device=device,
constants=constants,
num_warps=num_warps,
num_ctas=num_ctas,
num_stages=num_stages,
enable_warp_specialization=enable_warp_specialization,
enable_fp_fusion=enable_fp_fusion,
extern_libs=extern_libs,
configs=configs,
debug=self.debug,
device_type=device_type,
)
bin = self.cache[device][key]
if not warmup:
bin.c_wrapper(
grid_0,
grid_1,
grid_2,
bin.num_warps,
bin.num_ctas,
bin.clusterDims[0],
bin.clusterDims[1],
bin.clusterDims[2],
bin.shared,
stream,
bin.cu_function,
CompiledKernel.launch_enter_hook,
CompiledKernel.launch_exit_hook,
bin,
*bin.assemble_tensormap_to_arg(non_constexpr_arg_values),
)
return bin
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None):
do_not_specialize = do_not_specialize if do_not_specialize else []
self.fn = fn
self.module = fn.__module__
self.version = version
# function signature information
signature = inspect.signature(fn)
self.arg_names = [v.name for v in signature.parameters.values()]
self.arg_defaults = [v.default for v in signature.parameters.values()]
self.has_defaults = any(v != inspect._empty for v in self.arg_defaults)
self.signature = inspect.signature(fn)
self.do_not_specialize = do_not_specialize
self.params = []
for i, param in enumerate(self.signature.parameters.values()):
dns = do_not_specialize and (i in do_not_specialize or param.name in do_not_specialize)
self.params.append(KernelParam(i, param, dns))
# function source code (without decorators)
self.src = textwrap.dedent(inspect.getsource(fn))
self.src = self.src[self.src.find("def"):]
@@ -470,22 +757,18 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
self.hash = None
# JITFunction can be instantiated as kernel
# when called with a grid using __getitem__
self.kernel_decorators = []
self.kernel = None
self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug
self.noinline = noinline
# annotations
self.__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
# index of constexprs
self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
# specialization hints
regular_args = [arg for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
self.do_not_specialize = {regular_args.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
# tma info
self.tensormaps_info = TMAInfos()
# launcher
self.run = self._make_launcher()
# TODO(jlebar): Remove uses of these fields outside this file, then
# remove the fields here.
self.arg_names = [p.name for p in self.params]
self.constexprs = [p.num for p in self.params if p.is_constexpr]
# re-use docs of wrapped function
self.__doc__ = fn.__doc__
self.__name__ = fn.__name__
@@ -498,7 +781,7 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
if self.hash is None:
dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
dependencies_finder.visit(self.parse())
self.hash = dependencies_finder.ret + version_key()
self.hash = dependencies_finder.ret
return self.hash
def warmup(self, *args, **kwargs):
@@ -518,14 +801,10 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
def __setattr__(self, name, value):
# - when kernel decorators change, cached kernel
# needs to be cleared
if name == 'kernel_decorators':
self.kernel = None
super(JITFunction, self).__setattr__(name, value)
# - when `.src` attribute is set, cache path needs
# to be reinitialized
if name == 'src':
if name == "src":
self.hash = None
def __repr__(self):
@@ -591,12 +870,14 @@ def jit(
debug=debug,
noinline=noinline,
)
if fn is not None:
return decorator(fn)
else:
return decorator
# -----------------------------------------------------------------------------
# Utilities for mocking tensors
# -----------------------------------------------------------------------------
@@ -607,10 +888,10 @@ class MockTensor:
Can be used in place of real tensors when calling:
kernel.warmup(MockTensor(torch.float32), ...)
"""
@staticmethod
def wrap_dtype(arg):
if arg.__class__.__name__ == "dtype" and\
arg.__module__ == "torch":
if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch":
return MockTensor(arg)
return arg
@@ -623,6 +904,7 @@ class MockTensor:
class TensorWrapper:
def __init__(self, base, dtype):
self.dtype = dtype
self.base = base
@@ -637,7 +919,7 @@ class TensorWrapper:
return self.base.stride(i)
def __str__(self) -> str:
return f'TensorWrapper[{self.dtype}]({self.base})'
return f"TensorWrapper[{self.dtype}]({self.base})"
def element_size(self):
return self.base.element_size()
@@ -655,4 +937,4 @@ def reinterpret(tensor, dtype):
# A new wrapper is needed around an unwrapped tensor.
return TensorWrapper(tensor, dtype)
else:
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
raise TypeError(f"Cannot reinterpret a {type(tensor)}.")

View File

@@ -78,10 +78,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None):
return torch.mean(torch.tensor(ret)).item()
def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
quantiles=None,
fast_flush=True,
return_mode="mean"):
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"):
assert return_mode in ["min", "max", "mean", "median"]
import torch
"""
@@ -261,11 +258,12 @@ class Benchmark:
class Mark:
def __init__(self, fn, benchmarks):
self.fn = fn
self.benchmarks = benchmarks
def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, **kwrags):
def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False, **kwrags):
import os
import matplotlib.pyplot as plt
@@ -321,24 +319,36 @@ class Mark:
if save_path:
plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
df = df[x_names + bench.line_names]
if diff_col and df.shape[1] == 2:
col0, col1 = df.columns.tolist()
df['Diff'] = df[col1] - df[col0]
if print_data:
print(bench.plot_name + ':')
print(df)
if save_path:
df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)
return df
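The new `diff_col` flag above only applies when the selected frame has exactly two columns; a toy pandas illustration of what it appends (column names here are hypothetical).

# Toy illustration of the Diff column: second column minus the first.
import pandas as pd

df = pd.DataFrame({"triton": [800.0, 910.0], "torch": [780.0, 905.0]})
col0, col1 = df.columns.tolist()
df["Diff"] = df[col1] - df[col0]
print(df)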
def run(self, show_plots=False, print_data=False, save_path='', **kwargs):
def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs):
has_single_bench = isinstance(self.benchmarks, Benchmark)
benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
result_dfs = []
if save_path:
html = open(os.path.join(save_path, "results.html"), "w")
html.write("<html><body>\n")
for bench in benchmarks:
self._run(bench, save_path, show_plots, print_data, **kwargs)
result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
if save_path:
html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
if save_path:
html.write("</body></html>\n")
if return_df:
if has_single_bench:
return result_dfs[0]
else:
return result_dfs
return None
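A usage sketch for the new `return_df` flag, under the assumption of a CUDA build of torch; the benchmark itself is a trivial copy kernel invented for illustration.

# Sketch: collect benchmark results as a DataFrame instead of only printing/plotting.
import torch
import triton
import triton.testing

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'], x_vals=[2**i for i in range(10, 14)], line_arg='provider',
        line_vals=['torch'], line_names=['Torch'], ylabel='GB/s',
        plot_name='copy-performance', args={}))
def bench_copy(N, provider):
    x = torch.randn(N, device='cuda')
    ms = triton.testing.do_bench(lambda: x.clone())
    return 2 * x.numel() * x.element_size() / ms * 1e-6

df = bench_copy.run(print_data=True, return_df=True)  # single Benchmark -> one DataFrame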
def perf_report(benchmarks):
@@ -393,12 +403,15 @@ def get_max_tensorcore_tflops(dtype, clock_rate, backend=None, device=None):
tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
return tflops
# create decorator that wraps test function into
# a cuda-memcheck system call
def cuda_memcheck(**target_kwargs):
def decorator(test_fn):
@functools.wraps(test_fn)
def wrapper(*args, **kwargs):
import psutil
@@ -416,7 +429,9 @@ def cuda_memcheck(**target_kwargs):
assert "ERROR SUMMARY: 0 errors" in str(out.stdout)
else:
test_fn(*args, **kwargs)
return wrapper
return decorator
@@ -424,22 +439,18 @@ def cuda_memcheck(**target_kwargs):
def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
try:
subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"])
subprocess.check_output(
[
"nvidia-smi",
"-i",
"0",
f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
]
)
subprocess.check_output(
[
"nvidia-smi",
"-i",
"0",
f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
]
)
subprocess.check_output([
"nvidia-smi",
"-i",
"0",
f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
])
subprocess.check_output([
"nvidia-smi",
"-i",
"0",
f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
])
cur_sm_clock = nvsmi(["clocks.current.sm"])[0]
cur_mem_clock = nvsmi(["clocks.current.memory"])[0]
assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz"

View File

@@ -141,8 +141,7 @@ class ExternLibrary(ABC):
f.write(file_str)
f.close()
if self._format:
subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file],
stdout=subprocess.PIPE).communicate()
subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], stdout=subprocess.PIPE).communicate()
subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate()
@@ -208,56 +207,36 @@ class Libdevice(ExternLibrary):
# Group functions together by renaming.
renaming = {
'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh',
'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': 'add_rn',
'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru',
'dadd_rz': 'add_rz', 'fadd_rz': 'add_rz', 'asinf': 'asin',
'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2',
'atanhf': 'atanh', 'brevll': 'brev', 'cbrtf': 'cbrt',
'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign',
'cosf': 'cos', 'coshf': 'cosh', 'cospif': 'cospi',
'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn',
'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', 'ddiv_ru': 'div_ru',
'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf',
'erfcf': 'erfc', 'erfcinvf': 'erfcinv', 'erfcxf': 'erfcx',
'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10',
'exp2f': 'exp2', 'expm1f': 'expm1', 'fabsf': 'abs',
'fabs': 'abs', 'fast_fdividef': 'fast_dividef',
'fdimf': 'fdim', 'ffsll': 'ffs', 'floorf': 'floor',
'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn',
'fmaf_ru': 'fma_ru', 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod',
'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb',
'isinff': 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan',
'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint',
'llroundf': 'llround', 'logf': 'log', 'log10f': 'log10',
'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb',
'umax': 'max', 'llmax': 'max', 'ullmax': 'max', 'fmaxf': 'max',
'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min',
'fminf': 'min', 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd',
'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', 'dmul_ru': 'mul_ru',
'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz',
'umul24': 'mul24', 'umulhi': 'mulhi', 'mul64hi': 'mulhi',
'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': 'nextafter',
'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf',
'normcdfinvf': 'normcdfinv', 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow',
'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', 'drcp_rd': 'rcp_rd',
'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru',
'drcp_ru': 'rcp_ru', 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz',
'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d',
'roundf': 'round', 'rsqrtf': 'rsqrt', 'frsqrt_rn': 'rsqrt_rn',
'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit',
'signbitd': 'signbit', 'sinf': 'sin', 'sinhf': 'sinh',
'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd',
'dsqrt_rd': 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn',
'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', 'fsqrt_rz': 'sqrt_rz',
'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd',
'fsub_rn': 'sub_rn', 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru',
'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc',
'y0f': 'y0', 'y1f': 'y1', 'ynf': 'yn'
'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn':
'add_rn', 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', 'dadd_rz': 'add_rz', 'fadd_rz':
'add_rz', 'asinf': 'asin', 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', 'atanhf': 'atanh',
'brevll': 'brev', 'cbrtf': 'cbrt', 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', 'cosf': 'cos',
'coshf': 'cosh', 'cospif': 'cospi', 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru',
'ddiv_ru': 'div_ru', 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', 'erfcf': 'erfc', 'erfcinvf':
'erfcinv', 'erfcxf': 'erfcx', 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', 'exp2f': 'exp2',
'expm1f': 'expm1', 'fabsf': 'abs', 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', 'fdimf': 'fdim', 'ffsll':
'ffs', 'floorf': 'floor', 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', 'fmaf_ru': 'fma_ru',
'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', 'isinff':
'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', 'llroundf': 'llround', 'logf': 'log', 'log10f':
'log10', 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', 'umax': 'max', 'llmax': 'max', 'ullmax':
'max', 'fmaxf': 'max', 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', 'fminf': 'min',
'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn',
'dmul_ru': 'mul_ru', 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', 'umul24': 'mul24',
'umulhi': 'mulhi', 'mul64hi': 'mulhi', 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf':
'nextafter', 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', 'normcdfinvf': 'normcdfinv',
'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd',
'drcp_rd': 'rcp_rd', 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', 'drcp_ru': 'rcp_ru',
'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', 'roundf': 'round', 'rsqrtf': 'rsqrt',
'frsqrt_rn': 'rsqrt_rn', 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', 'signbitd': 'signbit',
'sinf': 'sin', 'sinhf': 'sinh', 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', 'dsqrt_rd':
'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru',
'fsqrt_rz': 'sqrt_rz', 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', 'fsub_rn': 'sub_rn',
'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', 'y0f': 'y0', 'y1f': 'y1', 'ynf':
'yn'
}
for symbol in self._symbols.values():
@@ -347,8 +326,7 @@ class LLVMDisassembler:
self._ll_file = "/tmp/extern_lib.ll"
def disasm(self, lib_path: str) -> None:
subprocess.Popen([self._path, lib_path, "-o", self.ll_file],
stdout=subprocess.PIPE).communicate()
subprocess.Popen([self._path, lib_path, "-o", self.ll_file], stdout=subprocess.PIPE).communicate()
@property
def ll_file(self) -> str:

View File

@@ -40,10 +40,13 @@ if __name__ == "__main__":
# command-line arguments
parser = ArgumentParser(description=desc)
parser.add_argument("path", help="Path to Python source containing desired kernel in its scope. File will be executed.")
parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", required=True)
parser.add_argument("path",
help="Path to Python source containing desired kernel in its scope. File will be executed.")
parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile",
required=True)
parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel")
parser.add_argument("--num-stages", "-ns", type=int, default=3, help="Number of stages (meta-parameter of the kernel)")
parser.add_argument("--num-stages", "-ns", type=int, default=3,
help="Number of stages (meta-parameter of the kernel)")
parser.add_argument("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel")
parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename")
parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True)
@@ -104,7 +107,8 @@ if __name__ == "__main__":
config = triton.compiler.instance_descriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1)
for i in equal_to_1:
constexprs.update({i: 1})
ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config], num_warps=args.num_warps, num_stages=args.num_stages)
ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config],
num_warps=args.num_warps, num_stages=args.num_stages)
arg_names = []
arg_types = []
for i in signature.keys():

View File

@@ -27,6 +27,7 @@ class KernelLinkerMeta:
class HeaderParser:
def __init__(self) -> None:
import re
@@ -42,7 +43,6 @@ class HeaderParser:
self.kernels = defaultdict(list)
def extract_linker_meta(self, header: str):
for ln in header.splitlines():
if ln.startswith("//"):
m = self.linker_directives.match(ln)
@@ -76,7 +76,7 @@ class HeaderParser:
m = self.c_sig.findall(c_sig)
if len(m):
tys, args = [], []
for (ty, arg_name) in m:
for ty, arg_name in m:
tys.append(ty)
args.append(arg_name)
return tys, args
@@ -84,7 +84,7 @@ class HeaderParser:
raise LinkerError(f"{c_sig} is not a valid argument signature")
def _match_suffix(self, suffix: str, c_sig: str):
args = c_sig.split(',')
args = c_sig.split(",")
s2i = {"c": 1, "d": 16}
num_specs = 0
sizes = []
@@ -110,7 +110,7 @@ class HeaderParser:
if name in self.kernels:
last: KernelLinkerMeta = self.kernels[name][-1]
for (cur, new_) in zip(last.arg_ctypes, ker.arg_ctypes):
for cur, new_ in zip(last.arg_ctypes, ker.arg_ctypes):
if cur != new_:
raise LinkerError(
f"Mismatched signature for kernel {name}: \n\texisting sig is: {','.join(last.arg_ctypes)}\n\tcurrent is: {','.join(ker.arg_ctypes)}"
@@ -152,7 +152,7 @@ void unload_{meta.orig_kernel_name}();
# generate dispatcher function for kernels with different meta-parameter and constant values
def make_default_algo_kernel(meta: KernelLinkerMeta) -> str:
src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n"
src += f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n"
src += (f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n")
src += "}\n"
return src
@@ -164,12 +164,22 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -
src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n"
src += "\n"
src += f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{"
src += (f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{")
src += "\n"
for meta in sorted(metas, key=lambda m: -m.num_specs):
cond_fn = lambda val, hint: f"({val} % {hint} == 0)" if hint == 16 else f"({val} == {hint})" if hint == 1 else None
conds = " && ".join([cond_fn(val, hint) for val, hint in zip(meta.arg_names, meta.sizes) if hint is not None])
src += f" if ({conds})\n"
cond_fn = ( #
lambda val, hint: f"({val} % {hint} == 0)" #
if hint == 16 #
else f"({val} == {hint})" #
if hint == 1 #
else None)
conds = " && ".join([ #
cond_fn(val, hint) #
for val, hint in zip(meta.arg_names, meta.sizes) #
if hint is not None
])
src += (f" if ({conds})\n" if any(meta.sizes) else "if (1)\n"
) # Edge case where no specializations hence no dispatching required
arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1]
src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n"
src += "\n"
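For clarity, a standalone sketch of the condition string the refactored cond_fn builds for one specialized variant; the argument names and hint sizes are hypothetical linker metadata (16 = "divisible by 16", 1 = "equal to 1").

# Standalone sketch of the dispatch condition built above.
cond_fn = lambda val, hint: (f"({val} % {hint} == 0)" if hint == 16 else
                             f"({val} == {hint})" if hint == 1 else None)
arg_names = ["C", "A", "B", "stride_cm"]
sizes = [16, 16, None, 1]
conds = " && ".join(cond_fn(v, h) for v, h in zip(arg_names, sizes) if h is not None)
print(conds)  # (C % 16 == 0) && (A % 16 == 0) && (stride_cm == 1)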
@@ -183,7 +193,7 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -
src += f"void {mode}_{name}() {{"
src += "\n"
for meta in sorted(metas, key=lambda m: -m.num_specs):
src += f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n"
src += (f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n")
src += "}\n"
return src
@@ -252,7 +262,12 @@ if __name__ == "__main__":
help="Paths to header files to link. Must include linker directive annotations (autogenerated by ttc)",
)
parser.add_argument("--out", "-o", type=Path, help="Out filename")
parser.add_argument("--prefix", type=str, default="", help="String to prefix kernel dispatcher names")
parser.add_argument(
"--prefix",
type=str,
default="",
help="String to prefix kernel dispatcher names",
)
args = parser.parse_args()
# metadata

View File

@@ -25,14 +25,13 @@ import triton.language as tl
@triton.jit
def add_kernel(
x_ptr, # *Pointer* to first input vector.
y_ptr, # *Pointer* to second input vector.
output_ptr, # *Pointer* to output vector.
n_elements, # Size of the vector.
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.
# NOTE: `constexpr` so it can be used as a shape value.
):
def add_kernel(x_ptr, # *Pointer* to first input vector.
y_ptr, # *Pointer* to second input vector.
output_ptr, # *Pointer* to output vector.
n_elements, # Size of the vector.
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.
# NOTE: `constexpr` so it can be used as a shape value.
):
# There are multiple 'programs' processing different data. We identify which program
# we are here:
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
@@ -66,7 +65,7 @@ def add(x: torch.Tensor, y: torch.Tensor):
# The SPMD launch grid denotes the number of kernel instances that run in parallel.
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
# In this case, we use a 1D grid where the size is the number of blocks:
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
# NOTE:
# - Each torch.tensor object is implicitly converted into a pointer to its first element.
# - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
@@ -88,10 +87,8 @@ output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(
f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}'
)
print(f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}')
# %%
# Seems like we're good to go!
@@ -108,9 +105,7 @@ print(
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['size'], # Argument names to use as an x-axis for the plot.
x_vals=[
2 ** i for i in range(12, 28, 1)
], # Different possible values for `x_name`.
x_vals=[2**i for i in range(12, 28, 1)], # Different possible values for `x_name`.
x_log=True, # x axis is logarithmic.
line_arg='provider', # Argument name whose value corresponds to a different line in the plot.
line_vals=['triton', 'torch'], # Possible values for `line_arg`.
@@ -119,8 +114,7 @@ print(
ylabel='GB/s', # Label name for the y-axis.
plot_name='vector-add-performance', # Name for the plot. Used also as a file name for saving the plot.
args={}, # Values for function arguments not in `x_names` and `y_name`.
)
)
))
def benchmark(size, provider):
x = torch.rand(size, device='cuda', dtype=torch.float32)
y = torch.rand(size, device='cuda', dtype=torch.float32)

View File

@@ -71,10 +71,7 @@ def naive_softmax(x):
@triton.jit
def softmax_kernel(
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
BLOCK_SIZE: tl.constexpr
):
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
# The rows of the softmax are independent, so we parallelize across those
row_idx = tl.program_id(0)
# The stride represents how much we need to increase the pointer to advance 1 row
@@ -118,7 +115,7 @@ def softmax(x):
y = torch.empty_like(x)
# Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row
# of the input matrix
softmax_kernel[(n_rows,)](
softmax_kernel[(n_rows, )](
y,
x,
x.stride(0),
@@ -158,9 +155,7 @@ assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'], # argument names to use as an x-axis for the plot
x_vals=[
128 * i for i in range(2, 100)
], # different possible values for `x_name`
x_vals=[128 * i for i in range(2, 100)], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=[
'triton',
@@ -176,8 +171,7 @@ assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
ylabel="GB/s", # label name for the y-axis
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`
)
)
))
def benchmark(M, N, provider):
x = torch.randn(M, N, device='cuda', dtype=torch.float32)
quantiles = [0.5, 0.2, 0.8]

View File

@@ -165,6 +165,7 @@ import pytest
# provided configs
@triton.autotune(
configs=[
<<<<<<< HEAD
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
@@ -179,6 +180,24 @@ import pytest
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=8, num_stages=0),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, num_warps=4, num_stages=0),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, num_warps=4, num_stages=0),
=======
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
num_warps=8),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
num_warps=2),
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
],
key=['M', 'N', 'K'],
)
@@ -187,6 +206,7 @@ import pytest
})
@triton.jit
def matmul_kernel(
<<<<<<< HEAD
# Pointers to matrices
a_ptr, b_ptr, c_ptr,
# Matrix dimensions
@@ -202,6 +222,22 @@ def matmul_kernel(
EVEN_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
ACTIVATION: tl.constexpr,
=======
# Pointers to matrices
a_ptr, b_ptr, c_ptr,
# Matrix dimensions
M, N, K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, #
ACTIVATION: tl.constexpr #
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
@@ -300,16 +336,14 @@ def matmul(a, b, activation=""):
# Allocates output.
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
# 1D launch kernel where each block gets its own program.
grid = lambda META: (
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
)
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
ACTIVATION=activation
a, b, c, #
M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
ACTIVATION=activation #
)
return c
@@ -363,6 +397,7 @@ verbose = False
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # Argument names to use as an x-axis for the plot
<<<<<<< HEAD
x_vals=[
(1024, 1024, 1024),
(2048, 2048, 2048),
@@ -370,6 +405,9 @@ verbose = False
(8192, 8192, 8192),
(9728, 8192, 65536)
], # Different possible values for `x_name`
=======
x_vals=[128 * i for i in range(2, 33)], # Different possible values for `x_name`
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
line_arg='provider', # Argument name whose value corresponds to a different line in the plot
# Possible values for `line_arg`
line_vals=['rocblas', 'triton'],
@@ -380,8 +418,7 @@ verbose = False
ylabel="TFLOPS", # Label name for the y-axis
plot_name="matmul-performance", # Name for the plot, used also as a file name for saving the plot.
args={},
)
)
))
def benchmark(M, N, K, provider):
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)

View File

@@ -32,7 +32,6 @@ In doing so, you will learn about:
#
# Let's first take a look at the baseline implementation.
import tabulate
import torch
@@ -66,22 +65,22 @@ def dropout(x, x_keep, p):
output = torch.empty_like(x)
assert x.is_contiguous()
n_elements = x.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
_dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
return output
# Input tensor
x = torch.randn(size=(10,)).cuda()
x = torch.randn(size=(10, )).cuda()
# Dropout mask
p = 0.5
x_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()
x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).cuda()
#
output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([
["input"] + x.tolist(),
["keep mask"] + x_keep.tolist(),
["output"] + output.tolist()
["output"] + output.tolist(),
]))
# %%
@@ -134,23 +133,24 @@ def seeded_dropout(x, p, seed):
output = torch.empty_like(x)
assert x.is_contiguous()
n_elements = x.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
_seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
return output
x = torch.randn(size=(10,)).cuda()
x = torch.randn(size=(10, )).cuda()
# Compare this to the baseline - dropout mask is never instantiated!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)
print(tabulate.tabulate([
["input"] + x.tolist(),
["output (seed = 123)"] + output.tolist(),
["output (seed = 123)"] + output2.tolist(),
["output (seed = 512)"] + output3.tolist()
]))
print(
tabulate.tabulate([
["input"] + x.tolist(),
["output (seed = 123)"] + output.tolist(),
["output (seed = 123)"] + output2.tolist(),
["output (seed = 512)"] + output3.tolist(),
]))
# %%
# Et Voilà! We have a triton kernel that applies the same dropout mask provided the seed is the same!
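A two-line check of that claim, reusing the outputs computed just above (kept as asserts so the tutorial still runs top to bottom); note the second assert only holds with overwhelming probability, not certainty.

# Check the determinism claim using the outputs computed above.
assert torch.equal(output, output2)      # same seed (123): identical dropout pattern
assert not torch.equal(output, output3)  # different seed (512): almost surely a different pattern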

View File

@@ -126,24 +126,22 @@ def _layer_norm_fwd_fused(
# In Stage 2, the buffers are further reduced to compute the final :math:`\nabla_{w}` and :math:`\nabla_{b}`.
# In the following implementation, Stage 1 is implemented by the function :code:`_layer_norm_bwd_dx_fused` and Stage 2 is implemented by the function :code:`_layer_norm_bwd_dwdb`.
@triton.jit
def _layer_norm_bwd_dx_fused(
DX, # pointer to the input gradient
DY, # pointer to the output gradient
DW, # pointer to the partial sum of weights gradient
DB, # pointer to the partial sum of biases gradient
X, # pointer to the input
W, # pointer to the weights
B, # pointer to the biases
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
Lock, # pointer to the lock
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
eps, # epsilon to avoid division by zero
GROUP_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr
):
def _layer_norm_bwd_dx_fused(DX, # pointer to the input gradient
DY, # pointer to the output gradient
DW, # pointer to the partial sum of weights gradient
DB, # pointer to the partial sum of biases gradient
X, # pointer to the input
W, # pointer to the weights
B, # pointer to the biases
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
Lock, # pointer to the lock
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
eps, # epsilon to avoid division by zero
GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
# Map the program id to the elements of X, DX, and DY it should compute.
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE_N)
@@ -192,16 +190,13 @@ def _layer_norm_bwd_dx_fused(
@triton.jit
def _layer_norm_bwd_dwdb(
DW, # pointer to the partial sum of weights gradient
DB, # pointer to the partial sum of biases gradient
FINAL_DW, # pointer to the weights gradient
FINAL_DB, # pointer to the biases gradient
M, # GROUP_SIZE_M
N, # number of columns
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr
):
def _layer_norm_bwd_dwdb(DW, # pointer to the partial sum of weights gradient
DB, # pointer to the partial sum of biases gradient
FINAL_DW, # pointer to the weights gradient
FINAL_DB, # pointer to the biases gradient
M, # GROUP_SIZE_M
N, # number of columns
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
# Map the program id to the elements of DW and DB it should compute.
pid = tl.program_id(0)
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -258,9 +253,10 @@ class LayerNorm(torch.autograd.Function):
else:
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
# enqueue kernel
_layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
x_arg.stride(0), N, eps,
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1)
_layer_norm_fwd_fused[(M, )]( #
x_arg, y, weight, bias, mean, rstd, #
x_arg.stride(0), N, eps, #
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1)
ctx.save_for_backward(x, weight, bias, mean, rstd)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
@@ -280,23 +276,25 @@ class LayerNorm(torch.autograd.Function):
locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
_dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
_db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
dw = torch.empty((w.shape[0], ), dtype=w.dtype, device=w.device)
db = torch.empty((w.shape[0], ), dtype=w.dtype, device=w.device)
dx = torch.empty_like(dy)
# enqueue kernel using forward pass heuristics
# also compute partial sums for DW and DB
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
_layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
x_arg.stride(0), N, ctx.eps,
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
GROUP_SIZE_M=GROUP_SIZE_M,
num_warps=ctx.num_warps)
_layer_norm_bwd_dx_fused[(M, )]( #
dx, dy, _dw, _db, x, w, b, m, v, locks, #
x_arg.stride(0), N, ctx.eps, #
BLOCK_SIZE_N=ctx.BLOCK_SIZE, #
GROUP_SIZE_M=GROUP_SIZE_M, #
num_warps=ctx.num_warps)
grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
# accumulate partial sums in separate kernel
_layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=128, num_ctas=1)
_layer_norm_bwd_dwdb[grid](
_dw, _db, dw, db, GROUP_SIZE_M, N, #
BLOCK_SIZE_M=32, #
BLOCK_SIZE_N=128, num_ctas=1)
return dx, None, dw, db, None
@@ -340,10 +338,16 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),
styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
ylabel='GB/s',
<<<<<<< HEAD
plot_name='layer-norm-forward',
args={'M': 4096, 'dtype': torch.float16, 'mode': 'forward'}
)
)
=======
plot_name='layer-norm-backward',
args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'},
))
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
# create data
x_shape = (M, N)
@@ -356,24 +360,34 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c
quantiles = [0.5, 0.2, 0.8]
# utility functions
if provider == 'triton':
def y_fwd(): return layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704
if provider == 'torch':
def y_fwd(): return torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704
if provider == 'apex':
apex_layer_norm = apex.normalization.FusedLayerNorm(
w_shape).to(x.device).to(x.dtype)
def y_fwd(): return apex_layer_norm(x) # noqa: F811, E704
def y_fwd():
return layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704
if provider == 'torch':
def y_fwd():
return torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704
if provider == 'apex':
apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)
def y_fwd():
return apex_layer_norm(x) # noqa: F811, E704
# forward pass
if mode == 'forward':
gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500)
# backward pass
if mode == 'backward':
def gbps(ms): return 3 * x.numel() * x.element_size() / ms * 1e-6 # noqa: F811, E704
def gbps(ms):
return 3 * x.numel() * x.element_size() / ms * 1e-6 # noqa: F811, E704
y = y_fwd()
ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
quantiles=quantiles, grad_to_none=[x], rep=500)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), quantiles=quantiles,
grad_to_none=[x], rep=500)
return gbps(ms), gbps(max_ms), gbps(min_ms)

View File

@@ -29,6 +29,7 @@ if TORCH_HAS_FP8E5FNUZ:
TORCH_HAS_FP8 = True
@triton.jit
<<<<<<< HEAD
def _attn_fwd_inner(
acc, l_i, m_i, q,
K_block_ptr, V_block_ptr,
@@ -42,6 +43,14 @@ def _attn_fwd_inner(
N_CTX,
pre_load_v: tl.constexpr,
):
=======
def _attn_fwd_inner(acc, l_i, m_i, q, #
K_block_ptr, V_block_ptr, #
start_m, qk_scale, #
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, #
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #
N_CTX: tl.constexpr):
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
# range of values handled by this stage
if STAGE == 1:
lo, hi = 0, start_m * BLOCK_M
@@ -83,6 +92,7 @@ def _attn_fwd_inner(
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
return acc, l_i, m_i
# We don't run auto-tuning every time to keep the tutorial fast. Uncommenting
# the code below and commenting out the equivalent parameters is convenient for
# re-tuning.
@@ -99,6 +109,7 @@ def _attn_fwd_inner(
@triton.jit
<<<<<<< HEAD
def _attn_fwd(
Q, K, V, sm_scale, M, Out,
stride_qz, stride_qh, stride_qm, stride_qk,
@@ -113,6 +124,20 @@ def _attn_fwd(
BLOCK_N: tl.constexpr,
pre_load_v: tl.constexpr,
):
=======
def _attn_fwd(Q, K, V, sm_scale, M, Out, #
stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vk, stride_vn, #
stride_oz, stride_oh, stride_om, stride_on, #
Z, H, #
N_CTX: tl.constexpr, #
BLOCK_M: tl.constexpr, #
BLOCK_DMODEL: tl.constexpr, #
BLOCK_N: tl.constexpr, #
STAGE: tl.constexpr #
):
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
qvk_offset = off_hz * stride_qh
@@ -168,6 +193,7 @@ def _attn_fwd(
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
if STAGE & 1:
<<<<<<< HEAD
acc, l_i, m_i = _attn_fwd_inner(
acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
start_m,
@@ -175,11 +201,19 @@ def _attn_fwd(
4 - STAGE, offs_m, offs_n,
N_CTX, pre_load_v,
)
=======
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, #
start_m, qk_scale, #
BLOCK_M, BLOCK_DMODEL, BLOCK_N, #
4 - STAGE, offs_m, offs_n, N_CTX #
)
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
# stage 2: on-band
if STAGE & 2:
# barrier makes it easier for the compiler to schedule the
# two loops independently
tl.debug_barrier()
<<<<<<< HEAD
acc, l_i, m_i = _attn_fwd_inner(
acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
start_m,
@@ -187,6 +221,13 @@ def _attn_fwd(
2, offs_m, offs_n,
N_CTX, pre_load_v,
)
=======
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, #
start_m, qk_scale, #
BLOCK_M, BLOCK_DMODEL, BLOCK_N, #
2, offs_m, offs_n, N_CTX #
)
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
# epilogue
# write back m
acc = acc / l_i[:, None]
@@ -197,8 +238,14 @@ def _attn_fwd(
@triton.jit
def _attn_bwd_preprocess(O, DO, #
<<<<<<< HEAD
NewDO, Delta, #
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, #
=======
Delta, #
Z, H, N_CTX, #
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr #
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
@@ -212,6 +259,7 @@ def _attn_bwd_preprocess(O, DO, #
@triton.jit
<<<<<<< HEAD
def _bwd_kernel_dk_dv(
Q, K, V, sm_scale, Out, DO,
DK, DV,
@@ -422,11 +470,242 @@ def _bwd_kernel_dq(
order=(1, 0)
)
tl.store(DQ_block_ptr, (dq * sm_scale).to(tl.float16))
=======
def _attn_bwd_dkdv(dk, dv, #
Q, k, v, sm_scale, #
DO, #
M, D, #
# shared by Q/K/V/DO.
stride_tok, stride_d, #
H, N_CTX, BLOCK_M1: tl.constexpr, #
BLOCK_N1: tl.constexpr, #
BLOCK_DMODEL: tl.constexpr, #
# Filled in by the wrapper.
start_n, start_m, num_steps, #
MASK: tl.constexpr):
offs_m = start_m + tl.arange(0, BLOCK_M1)
offs_n = start_n + tl.arange(0, BLOCK_N1)
offs_k = tl.arange(0, BLOCK_DMODEL)
qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
curr_m = start_m
step_m = BLOCK_M1
for blk_idx in range(num_steps):
qT = tl.load(qT_ptrs)
# Load m before computing qk to reduce pipeline stall.
offs_m = curr_m + tl.arange(0, BLOCK_M1)
m = tl.load(M + offs_m)
qkT = tl.dot(k, qT)
pT = tl.math.exp2(qkT - m[None, :])
# Autoregressive masking.
if MASK:
mask = (offs_m[None, :] >= offs_n[:, None])
pT = tl.where(mask, pT, 0.0)
do = tl.load(do_ptrs)
# Compute dV.
ppT = pT
ppT = ppT.to(tl.float16)
dv += tl.dot(ppT, do)
# D (= delta) is pre-divided by ds_scale.
Di = tl.load(D + offs_m)
# Compute dP and dS.
dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
dsT = pT * (dpT - Di[None, :])
dsT = dsT.to(tl.float16)
dk += tl.dot(dsT, tl.trans(qT))
# Increment pointers.
curr_m += step_m
qT_ptrs += step_m * stride_tok
do_ptrs += step_m * stride_tok
return dk, dv
# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(dq, q, K, V, #
do, m, D,
# shared by Q/K/V/DO.
stride_tok, stride_d, #
H, N_CTX, #
BLOCK_M2: tl.constexpr, #
BLOCK_N2: tl.constexpr, #
BLOCK_DMODEL: tl.constexpr,
# Filled in by the wrapper.
start_m, start_n, num_steps, #
MASK: tl.constexpr):
offs_m = start_m + tl.arange(0, BLOCK_M2)
offs_n = start_n + tl.arange(0, BLOCK_N2)
offs_k = tl.arange(0, BLOCK_DMODEL)
kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
# D (= delta) is pre-divided by ds_scale.
Di = tl.load(D + offs_m)
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
curr_n = start_n
step_n = BLOCK_N2
for blk_idx in range(num_steps):
kT = tl.load(kT_ptrs)
vT = tl.load(vT_ptrs)
qk = tl.dot(q, kT)
p = tl.math.exp2(qk - m)
# Autoregressive masking.
if MASK:
offs_n = curr_n + tl.arange(0, BLOCK_N2)
mask = (offs_m[:, None] >= offs_n[None, :])
p = tl.where(mask, p, 0.0)
# Compute dP and dS.
dp = tl.dot(do, vT).to(tl.float32)
ds = p * (dp - Di[:, None])
ds = ds.to(tl.float16)
# Compute dQ.
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
dq += tl.dot(ds, tl.trans(kT))
# Increment pointers.
curr_n += step_n
kT_ptrs += step_n * stride_tok
vT_ptrs += step_n * stride_tok
return dq
@triton.jit
def _attn_bwd(Q, K, V, sm_scale, #
DO, #
DQ, DK, DV, #
M, D,
# shared by Q/K/V/DO.
stride_z, stride_h, stride_tok, stride_d, #
H, N_CTX, #
BLOCK_M1: tl.constexpr, #
BLOCK_N1: tl.constexpr, #
BLOCK_M2: tl.constexpr, #
BLOCK_N2: tl.constexpr, #
BLK_SLICE_FACTOR: tl.constexpr, #
BLOCK_DMODEL: tl.constexpr):
LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
bhid = tl.program_id(2)
off_chz = (bhid * N_CTX).to(tl.int64)
adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
pid = tl.program_id(0)
# offset pointers for batch/head
Q += adj
K += adj
V += adj
DO += adj
DQ += adj
DK += adj
DV += adj
M += off_chz
D += off_chz
# load scales
offs_k = tl.arange(0, BLOCK_DMODEL)
start_n = pid * BLOCK_N1
start_m = start_n
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
offs_n = start_n + tl.arange(0, BLOCK_N1)
dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
# load K and V: they stay in SRAM throughout the inner loop.
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
num_steps = BLOCK_N1 // MASK_BLOCK_M1
dk, dv = _attn_bwd_dkdv(dk, dv, #
Q, k, v, sm_scale, #
DO, #
M, D, #
stride_tok, stride_d, #
H, N_CTX, #
MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, #
start_n, start_m, num_steps, #
MASK=True #
)
start_m += num_steps * MASK_BLOCK_M1
num_steps = (N_CTX - start_m) // BLOCK_M1
# Compute dK and dV for non-masked blocks.
dk, dv = _attn_bwd_dkdv( #
dk, dv, #
Q, k, v, sm_scale, #
DO, #
M, D, #
stride_tok, stride_d, #
H, N_CTX, #
BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, #
start_n, start_m, num_steps, #
MASK=False #
)
dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
tl.store(dv_ptrs, dv)
# Write back dK.
dk *= sm_scale
dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
tl.store(dk_ptrs, dk)
# THIS BLOCK DOES DQ:
start_m = pid * BLOCK_M2
end_n = start_m + BLOCK_M2
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
offs_m = start_m + tl.arange(0, BLOCK_M2)
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
m = tl.load(M + offs_m)
m = m[:, None]
# Compute dQ for masked (diagonal) blocks.
# NOTE: This code scans each row of QK^T backward (from right to left,
# but inside each call to _attn_bwd_dq, from left to right), but that's
# not due to anything important. I just wanted to reuse the loop
# structure for dK & dV above as much as possible.
num_steps = BLOCK_M2 // MASK_BLOCK_N2
dq = _attn_bwd_dq(dq, q, K, V, #
do, m, D, #
stride_tok, stride_d, #
H, N_CTX, #
BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL, #
start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, #
MASK=True #
)
end_n -= num_steps * MASK_BLOCK_N2
# stage 2
num_steps = end_n // BLOCK_N2
dq = _attn_bwd_dq(dq, q, K, V, #
do, m, D, #
stride_tok, stride_d, #
H, N_CTX, #
BLOCK_M2, BLOCK_N2, BLOCK_DMODEL, #
start_m, end_n - num_steps * BLOCK_N2, num_steps, #
MASK=False #
)
# Write back dQ.
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
dq *= LN2
tl.store(dq_ptrs, dq)
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
empty = torch.empty(128, device="cuda")
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
# shape constraints
@@ -453,6 +732,7 @@ class _attention(torch.autograd.Function):
M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
_attn_fwd[grid](
<<<<<<< HEAD
q, k, v, sm_scale, M, o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
@@ -462,6 +742,21 @@ class _attention(torch.autograd.Function):
N_CTX=q.shape[2],
BLOCK_DMODEL=Lk,
STAGE=stage,
=======
q, k, v, sm_scale, M, o, #
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
q.shape[0], q.shape[1], #
N_CTX=q.shape[2], #
BLOCK_M=BLOCK_M, #
BLOCK_N=BLOCK_N, #
BLOCK_DMODEL=Lk, #
STAGE=stage, #
num_warps=num_warps, #
num_stages=num_stages #
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
)
## restore the grid for bwd kernel
@@ -493,6 +788,7 @@ class _attention(torch.autograd.Function):
dk = torch.empty_like(k)
dv = torch.empty_like(v)
BATCH, N_HEAD, N_CTX = q.shape[:3]
<<<<<<< HEAD
delta = torch.empty_like(L)
do_scaled = torch.empty_like(do)
# Figure out what BLOCK size fwd used and adjust num_blocks accordingly.
@@ -506,6 +802,39 @@ class _attention(torch.autograd.Function):
o, do, #
do_scaled, delta, #
BLOCK_M=block_scale * BLOCK, D_HEAD=ctx.BLOCK_DMODEL, #
=======
PRE_BLOCK = 128
NUM_WARPS, NUM_STAGES = 4, 1
if torch.cuda.get_device_capability()[0] == 9:
NUM_STAGES = 5
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
BLK_SLICE_FACTOR = 2
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
arg_k = k
arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
PRE_BLOCK = 128
assert N_CTX % PRE_BLOCK == 0
pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
delta = torch.empty_like(M)
_attn_bwd_preprocess[pre_grid](
o, do, #
delta, #
BATCH, N_HEAD, N_CTX, #
BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL #
)
grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
_attn_bwd[grid](
q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, #
M, delta, #
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
N_HEAD, N_CTX, #
BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, #
BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, #
BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
BLOCK_DMODEL=ctx.BLOCK_DMODEL, #
num_warps=NUM_WARPS, #
num_stages=NUM_STAGES #
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
)
if not ctx.split_kernel:
_bwd_kernel[(ctx.grid[1],)](
@@ -599,11 +928,17 @@ def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
])
def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
<<<<<<< HEAD
causal = True
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
=======
q = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
k = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
v = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
sm_scale = 0.5
split_kernel = True
dout = torch.randn_like(q)
@@ -672,17 +1007,26 @@ for mode in ['fwd', 'bwd']:
ylabel='ms',
plot_name=f'fused-attention-{mode}-d{D_HEAD}-causal={causal}',
args={
<<<<<<< HEAD
'D_HEAD': D_HEAD,
'dtype': torch.float16,
'mode': mode,
'causal': causal})
)
=======
"H": N_HEADS,
"BATCH": BATCH,
"D_HEAD": D_HEAD,
"dtype": torch.float16,
"mode": mode,
"causal": causal,
},
))
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
@triton.testing.perf_report(configs)
def bench_flash_attention(
BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"
):
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"):
assert mode in ["fwd", "bwd"]
warmup = 25
rep = 100
@@ -706,9 +1050,7 @@ def bench_flash_attention(
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
if provider == "flash":
qkv = torch.randn(
(BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True
)
qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
fn = lambda: flash_attn_func(qkv, causal=causal)
if mode == "bwd":
o = fn()
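A note on the constants in the backward launch above: `RCP_LN2 = 1.4426950408889634` is log2(e) = 1/ln(2), K is pre-scaled by `ctx.sm_scale * RCP_LN2` before `_attn_bwd` runs, and the kernel's final `dq *= LN2` undoes the base change, which is consistent with the softmax being evaluated through exp2 rather than exp. A minimal PyTorch sketch of the identity involved (the input tensor is arbitrary and purely illustrative):
import math
import torch
RCP_LN2 = 1.4426950408889634  # log2(e) == 1 / ln(2)
x = torch.randn(8)
# exp(s) == 2 ** (s * log2(e)), so scaling the scores (or K itself) by
# sm_scale * RCP_LN2 up front lets a kernel use exp2 for the softmax.
assert torch.allclose(torch.exp(x), torch.exp2(x * RCP_LN2), rtol=1e-5, atol=1e-6)
print(math.log2(math.e))  # 1.4426950408889634, the same constant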


@@ -22,10 +22,10 @@ import triton.language as tl
@triton.jit
def asin_kernel(
x_ptr,
y_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
x_ptr,
y_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
@@ -35,12 +35,12 @@ def asin_kernel(
x = tl.math.asin(x)
tl.store(y_ptr + offsets, x, mask=mask)
# %%
# Using the default libdevice library path
# -----------------------------------------
# We can use the default libdevice library path encoded in `triton/language/math.py`
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
@@ -48,14 +48,12 @@ output_triton = torch.zeros(size, device='cuda')
output_torch = torch.asin(x)
assert x.is_cuda and output_triton.is_cuda
n_elements = output_torch.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)
print(output_torch)
print(output_triton)
print(
f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}'
)
print(f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}')
# %%
# Customize the libdevice library path
@@ -67,7 +65,5 @@ asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,
extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'})
print(output_torch)
print(output_triton)
print(
f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}'
)
print(f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}')
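The launch above, `grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )`, sizes a 1-D grid from the block size the kernel is launched with. A small worked version of that arithmetic for the tutorial's own numbers (pure Python, mirroring triton.cdiv's ceiling division):
def cdiv(x, y):
    # ceiling division, as triton.cdiv computes it
    return (x + y - 1) // y
n_elements = 98432
BLOCK_SIZE = 1024
num_programs = cdiv(n_elements, BLOCK_SIZE)   # 97 program instances
covered = (num_programs - 1) * BLOCK_SIZE     # 98304 elements before the last block
valid_in_last_block = n_elements - covered    # 128; the remaining 896 lanes do nothing
print(num_programs, valid_in_last_block)
The partial last block is why `asin_kernel` loads and stores with `mask=mask`, keeping only offsets below `n_elements`.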


@@ -98,14 +98,22 @@ import triton.language as tl
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
num_warps=2),
],
key=['M', 'N', 'K'],
)
@@ -118,13 +126,11 @@ def matmul_kernel_with_block_pointers(
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am, stride_ak,
stride_bk, stride_bn,
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr
):
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
@@ -196,16 +202,13 @@ def matmul(a, b):
# Allocates output.
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
# 1D launch kernel where each block gets its own program.
grid = lambda META: (
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
)
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
matmul_kernel_with_block_pointers[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
)
a, b, c, #
M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1))
return c
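The kernel above expresses its tiles through `tl.make_block_ptr` and steps them with `tl.advance`. A stripped-down sketch of that pattern, reduced to a tiled row copy so only the block-pointer mechanics remain; the kernel name, shapes, and block sizes here are illustrative, and the sizes are chosen to divide evenly so no boundary handling is needed:
import torch
import triton
import triton.language as tl

@triton.jit
def copy_rows_kernel(src_ptr, dst_ptr, M, N, stride_m, stride_n,  #
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(axis=0)
    # A block pointer records the whole tensor (shape, strides) plus the tile this
    # program currently points at (offsets, block_shape); order=(1, 0) marks the
    # last dimension as the contiguous one, as in the matmul tutorial above.
    src = tl.make_block_ptr(base=src_ptr, shape=(M, N), strides=(stride_m, stride_n),
                            offsets=(pid_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_N),
                            order=(1, 0))
    dst = tl.make_block_ptr(base=dst_ptr, shape=(M, N), strides=(stride_m, stride_n),
                            offsets=(pid_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_N),
                            order=(1, 0))
    for _ in range(0, N, BLOCK_N):
        tl.store(dst, tl.load(src))
        # slide both tiles one block to the right, like the tl.advance calls
        # in the K loop of matmul_kernel_with_block_pointers
        src = tl.advance(src, [0, BLOCK_N])
        dst = tl.advance(dst, [0, BLOCK_N])

x = torch.randn(256, 512, device='cuda', dtype=torch.float16)
y = torch.empty_like(x)
copy_rows_kernel[(triton.cdiv(256, 64), )](x, y, 256, 512, x.stride(0), x.stride(1),
                                           BLOCK_M=64, BLOCK_N=128)
assert torch.equal(x, y)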


@@ -40,23 +40,24 @@ if torch.cuda.get_device_capability()[0] < 9:
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7,
num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4, num_ctas=2),
# triton.Config({'BLOCK_SIZE_M': 512, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4, num_ctas=4),
],
key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, z_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_zm, stride_zn,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr,
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr
):
def matmul_kernel(a_ptr, b_ptr, z_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_zm, stride_zn, #
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, #
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, #
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr #
):
pid = tl.program_id(axis=0)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
@@ -70,9 +71,11 @@ def matmul_kernel(
block_offset_n = pid_n * BLOCK_SIZE_N
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), order=(A_ORDER_0, A_ORDER_1))
offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
order=(A_ORDER_0, A_ORDER_1))
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(B_ORDER_0, B_ORDER_1))
offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
order=(B_ORDER_0, B_ORDER_1))
z = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
offs_m = block_offset_m + tl.arange(0, BLOCK_SIZE_M)
@@ -101,15 +104,17 @@ def matmul(a, b, a_order, b_order):
z = torch.empty((M, N), device=a.device, dtype=torch.float16)
def grid(META):
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
matmul_kernel[grid](a_ptr=a, b_ptr=b, z_ptr=z,
M=M, N=N, K=K,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_zm=z.stride(0), stride_zn=z.stride(1),
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1],
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1]
)
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
matmul_kernel[grid](
a_ptr=a, b_ptr=b, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1] #
)
return z
@@ -160,14 +165,12 @@ def test_matmul():
# label name for the lines
line_names=["cuBLAS", "Triton"],
# line styles
styles=[('green', '-'), ('green', '--'),
('blue', '-'), ('blue', '--')],
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
ylabel="TFLOPS", # label name for the y-axis
plot_name="matmul-performance",
# name for the plot. Used also as a file name for saving the plot.
args={},
)
)
))
def benchmark(M, N, K, TRANS_A, TRANS_B, provider):
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
@@ -185,14 +188,15 @@ def benchmark(M, N, K, TRANS_A, TRANS_B, provider):
quantiles = [0.5, 0.2, 0.8]
if provider == 'cublas':
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: torch.matmul(a, b), rep=100, quantiles=quantiles, fast_flush=False)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100, quantiles=quantiles,
fast_flush=False)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: matmul(a, b, a_order, b_order), rep=100, quantiles=quantiles, fast_flush=False)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, a_order, b_order), rep=100,
quantiles=quantiles, fast_flush=False)
def perf(ms):
return 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)
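The `perf` helper above turns a measured time into TFLOPS with `2 * M * N * K * 1e-12 / (ms * 1e-3)`: a dense matmul performs about 2*M*N*K floating-point operations (one multiply plus one add per accumulated product), `1e-12` converts FLOPs to TFLOPs, and `ms * 1e-3` converts milliseconds to seconds. A worked instance with an assumed, purely illustrative timing:
M = N = K = 2048
flops = 2 * M * N * K          # 17,179,869,184 floating-point operations
ms = 0.5                       # hypothetical median time reported by do_bench, in ms
tflops = flops * 1e-12 / (ms * 1e-3)
print(f"{tflops:.2f} TFLOPS")  # 34.36 TFLOPS for this assumed timing
Since a smaller time means higher throughput, `perf(max_ms)` is the lower throughput estimate and `perf(min_ms)` the upper one, which is why the return order swaps them relative to the measured times.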


@@ -40,21 +40,21 @@ if torch.cuda.get_device_capability()[0] < 9:
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7,
num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 512, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4, num_ctas=4),
],
key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
def matmul_kernel(a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr #
):
pid = tl.program_id(axis=0)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
@@ -67,20 +67,10 @@ def matmul_kernel(
block_offset_m = pid_m * BLOCK_SIZE_M
block_offset_n = pid_n * BLOCK_SIZE_N
a_tile_ptr = tl.make_block_ptr(
base=a_ptr, shape=(
M, K), strides=(
stride_am, stride_ak), offsets=(
block_offset_m, 0), block_shape=(
BLOCK_SIZE_M, BLOCK_SIZE_K), order=(
1, 0))
b_tile_ptr = tl.make_block_ptr(
base=b_ptr, shape=(
K, N), strides=(
stride_bk, stride_bn), offsets=(
0, block_offset_n), block_shape=(
BLOCK_SIZE_K, BLOCK_SIZE_N), order=(
0, 1))
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0))
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(0, 1))
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
@@ -91,7 +81,8 @@ def matmul_kernel(
b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0])
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N),
order=(1, 0))
tl.store(c_block_ptr, accumulator)
@@ -101,20 +92,19 @@ def matmul(a, b):
assert a.shape[1] == b.shape[0], "incompatible dimensions"
M, K = a.shape
K, N = b.shape
assert (
K % 32 == 0
), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"
assert (K % 32 == 0), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
def grid(META):
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
M=M, N=N, K=K,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1))
matmul_kernel[grid](
a_ptr=a, b_ptr=b, c_ptr=c, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_cm=c.stride(0), stride_cn=c.stride(1))
return c
@@ -126,12 +116,7 @@ c = torch.nn.functional.normalize(c)
golden = torch.nn.functional.normalize(torch.matmul(a, b))
torch.set_printoptions(profile="full")
assert_close(
c,
golden,
rtol=1e-2,
atol=1e-3,
check_dtype=False)
assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
@triton.testing.perf_report(
@@ -143,7 +128,7 @@ assert_close(
[2048, 1024, 1024],
[2048, 2048, 2048],
[2048, 4096, 4096],
[2048, 8192, 8192]
[2048, 8192, 8192],
], # different possible values for `x_name`
line_arg='provider',
# argument name whose value corresponds to a different line in the plot
@@ -152,27 +137,26 @@ assert_close(
# label name for the lines
line_names=["cuBLAS", "Triton"],
# line styles
styles=[('green', '-'), ('green', '--'),
('blue', '-'), ('blue', '--')],
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
ylabel="TFLOPS", # label name for the y-axis
plot_name="matmul-performance",
# name for the plot. Used also as a file name for saving the plot.
args={},
)
)
))
def benchmark(M, N, K, provider):
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
quantiles = [0.5, 0.2, 0.8]
if provider == 'cublas':
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: torch.matmul(a, b), rep=100, quantiles=quantiles, fast_flush=False)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100, quantiles=quantiles,
fast_flush=False)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: matmul(a, b), rep=100, quantiles=quantiles, fast_flush=False)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), rep=100, quantiles=quantiles,
fast_flush=False)
def perf(ms):
return 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)
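On the accuracy check earlier in this file: both results are normalized and then compared with `assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False)`. `torch.testing.assert_close` accepts each element when |actual - expected| <= atol + rtol * |expected|; a small sketch of that criterion with made-up numbers:
import torch
expected = torch.tensor([0.500, 0.020, 0.000])
actual = torch.tensor([0.504, 0.021, 0.0005])
rtol, atol = 1e-2, 1e-3
bound = atol + rtol * expected.abs()           # per-element allowance
print(bound)                                   # tensor([0.0060, 0.0012, 0.0010])
print((actual - expected).abs() <= bound)      # tensor([True, True, True])
torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)  # passes
Normalizing `c` and `golden` first keeps both operands near unit scale, so the fixed tolerances stay meaningful as M, N and K grow.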


@@ -1,3 +1,13 @@
<<<<<<< HEAD
=======
"""
Group GEMM
============================
This group gemm kernel launches a fixed number of CTA to compute a group
of gemms. The scheduling is static and we do it on device.
"""
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining
@@ -28,6 +38,7 @@ import triton.language as tl
# of gemms. The scheduling is static and we do it on device
@triton.autotune(
configs=[
<<<<<<< HEAD
triton.Config(
{
'BLOCK_SIZE_M': 128,
@@ -111,6 +122,32 @@ import triton.language as tl
num_stages = 0,
num_warps = 2,
),
=======
triton.Config({
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32,
'NUM_SM': 84,
}),
triton.Config({
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32,
'NUM_SM': 128,
}),
triton.Config({
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'NUM_SM': 84,
}),
triton.Config({
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'NUM_SM': 128,
}),
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
],
key=['SUM_M', 'SUM_N', 'SUM_K'],
)
@@ -149,9 +186,7 @@ def grouped_matmul_kernel(
num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
num_tiles = num_m_tiles * num_n_tiles
# iterate through the tiles in the current gemm problem
while (
tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles
):
while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles):
# pick up a tile from the current gemm problem
k = gk
lda = tl.load(g_lds + g * 3)
@@ -171,9 +206,7 @@ def grouped_matmul_kernel(
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]
b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]
accumulator = tl.zeros(
(BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
# hint to Triton compiler to do proper loop pipelining
tl.multiple_of(a_ptrs, [16, 16])
@@ -224,7 +257,7 @@ def group_gemm_fn(group_A, group_B):
group_C.append(C)
A_addrs.append(A.data_ptr())
B_addrs.append(B.data_ptr())
C_addrs .append(C.data_ptr())
C_addrs.append(C.data_ptr())
g_sizes += [M, N, K]
SUM_M += M
SUM_N += N
@@ -235,14 +268,10 @@ def group_gemm_fn(group_A, group_B):
d_a_ptrs = torch.tensor(A_addrs, device=device)
d_b_ptrs = torch.tensor(B_addrs, device=device)
d_c_ptrs = torch.tensor(C_addrs, device=device)
d_g_sizes = torch.tensor(
g_sizes, dtype=torch.int32, device=device
)
d_g_lds = torch.tensor(
g_lds, dtype=torch.int32, device=device
)
d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=device)
d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device)
# we use a fixed number of CTA, and it's auto-tunable
grid = lambda META: (META['NUM_SM'],)
grid = lambda META: (META['NUM_SM'], )
grouped_matmul_kernel[grid](
d_a_ptrs,
d_b_ptrs,
@@ -283,8 +312,13 @@ for i in range(group_size):
# only launch the kernel, no tensor preparation here to remove all overhead
<<<<<<< HEAD
def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size, sum_m, sum_n, sum_k):
grid = lambda META: (META['NUM_SM'],)
=======
def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size):
grid = lambda META: (META['NUM_SM'], )
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
grouped_matmul_kernel[grid](
a_ptrs,
b_ptrs,
@@ -307,7 +341,7 @@ def torch_perf_fn(group_A, group_B):
triton.testing.Benchmark(
# argument names to use as an x-axis for the plot
x_names=['N'],
x_vals=[2 ** i for i in range(7, 11)], # different possible values for `x_name`
x_vals=[2**i for i in range(7, 11)], # different possible values for `x_name`
line_arg='provider',
# argument name whose value corresponds to a different line in the plot
# possible values for `line_arg``
@@ -320,8 +354,7 @@ def torch_perf_fn(group_A, group_B):
plot_name="group-gemm-performance",
# name for the plot. Used also as a file name for saving the plot.
args={},
)
)
))
def benchmark(N, provider):
group_size = 4
group_A = []
@@ -341,7 +374,7 @@ def benchmark(N, provider):
group_C.append(C)
A_addrs.append(A.data_ptr())
B_addrs.append(B.data_ptr())
C_addrs .append(C.data_ptr())
C_addrs.append(C.data_ptr())
g_sizes += [N, N, N]
g_lds += [N, N, N]
@@ -355,7 +388,12 @@ def benchmark(N, provider):
if provider == 'cublas':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles)
if provider == 'triton':
<<<<<<< HEAD
ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size, group_size*N, group_size*N, group_size*N), quantiles=quantiles)
=======
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size), quantiles=quantiles)
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
return ms, max_ms, min_ms
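The docstring of this tutorial says the group GEMM launches a fixed number of CTAs and schedules the work statically on the device: the grid is just `(META['NUM_SM'], )`, and each program walks a global `tile_idx` across every tile of every problem (the `while tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles` loop above). A host-side Python sketch of that kind of static schedule, assuming the usual persistent-kernel striding in which program p takes global tiles p, p + NUM_SM, p + 2 * NUM_SM, and so on; `block_m`/`block_n` stand in for `BLOCK_SIZE_M`/`BLOCK_SIZE_N`, and the problem sizes below are made up:
def cdiv(x, y):
    return (x + y - 1) // y

def static_schedule(problem_sizes, num_programs, block_m=64, block_n=64):
    """Map each global tile index to (problem, local_tile) and hand it to
    program (global_tile % num_programs), mimicking the on-device loop."""
    assignments = {p: [] for p in range(num_programs)}
    last_problem_end = 0
    for g, (m, n, _k) in enumerate(problem_sizes):
        num_tiles = cdiv(m, block_m) * cdiv(n, block_n)
        for local_tile in range(num_tiles):
            global_tile = last_problem_end + local_tile
            assignments[global_tile % num_programs].append((g, local_tile))
        last_problem_end += num_tiles
    return assignments

# four equally sized problems, like the benchmark's group_size = 4
for program, tiles in static_schedule([(256, 256, 256)] * 4, num_programs=8).items():
    print(program, tiles)
Because the number of programs is fixed, raising NUM_SM in the autotuner trades fewer tiles per program for more resident CTAs; the kernel itself is launched only once.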