Mirror of https://github.com/ROCm/ROCm.git, synced 2026-02-21 03:00:39 -05:00
Merge commit 'cb3d79a185e40c9d8a579bea07747a8a8d157d52' into ifu-231117
Conflicts:
    lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
    lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
    lib/Dialect/TritonGPU/IR/Dialect.cpp
    python/setup.py
    python/test/unit/language/assert_helper.py
    python/test/unit/operators/test_flash_attention.py
    python/test/unit/runtime/test_subproc.py
    python/triton/compiler/compiler.py
    python/triton/language/semantic.py
    python/triton/runtime/autotuner.py
    python/triton/runtime/jit.py
    python/tutorials/03-matrix-multiplication.py
    python/tutorials/05-layer-norm.py
    python/tutorials/06-fused-attention.py
    python/tutorials/11-grouped-gemm.py
    test/Conversion/tritongpu_to_llvm.mlir
@@ -4,8 +4,8 @@ import triton.language as tl

# triton kernel
@triton.jit
def kernel(X, stride_xm,
Z, stride_zn,
def kernel(X, stride_xm, #
Z, stride_zn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)

@@ -10,4 +10,4 @@ def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):

X = torch.randn(1, device="cuda")
pgm = kernel[(1,)](X, 1, 1, BLOCK=1024)
pgm = kernel[(1, )](X, 1, 1, BLOCK=1024)

@@ -2,7 +2,14 @@
[build-system]
requires = ["setuptools>=40.8.0", "wheel", "cmake>=3.18", "ninja>=1.11.1"]

# We're incrementally switching from autopep8 to ruff.
[tool.autopep8]
aggressive = 1
ignore = "E501,E701,E731,W690"
ignore = "E501,E701,E731,W690,W503"
max_line_length = 88

[tool.ruff]
line-length = 120

[tool.ruff.lint]
ignore = ["E501", "E701", "E731", "E741"]

@@ -55,6 +55,7 @@ class Package(NamedTuple):
lib_flag: str
syspath_var_name: str

# pybind11

@@ -63,6 +64,7 @@ def get_pybind11_package_info():
url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.11.1.tar.gz"
return Package("pybind11", name, url, "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH")

# llvm

@@ -74,6 +76,8 @@ def get_llvm_package_info():
arch = 'arm64'
if system == "Darwin":
arch = platform.machine()
if arch == "x86_64":
arch = "x64"
system_suffix = f"macos-{arch}"
elif system == "Linux":
# TODO: arm64
@@ -84,7 +88,7 @@ def get_llvm_package_info():
return Package("llvm", "LLVM-C.lib", "", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
# use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
# release_suffix = "assert" if use_assert_enabled_llvm else "release"
rev = "b1115f8c"
rev = "49af6502"
name = f"llvm-{rev}-{system_suffix}"
url = f"https://tritonlang.blob.core.windows.net/llvm-builds/{name}.tar.gz"
return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
@@ -119,10 +123,13 @@ def get_thirdparty_packages(triton_cache_path):
thirdparty_cmake_args.append(f"-D{p.lib_flag}={package_dir}/lib")
return thirdparty_cmake_args

# ---- package data ---

def download_and_copy(src_path, version, url_func):
def download_and_copy(src_path, variable, version, url_func):
if variable in os.environ:
return
base_dir = os.path.dirname(__file__)
arch = platform.machine()
if arch == "x86_64":
@@ -148,7 +155,7 @@ def download_and_copy(src_path, version, url_func):
src_path = os.path.join(temp_dir, src_path)
os.makedirs(os.path.split(dst_path)[0], exist_ok=True)
shutil.copy(src_path, dst_path)
return dst_suffix

# ---- cmake extension ----

@@ -167,18 +174,21 @@ def get_cmake_dir():

class CMakeClean(clean):

def initialize_options(self):
clean.initialize_options(self)
self.build_temp = get_cmake_dir()

class CMakeBuildPy(build_py):

def run(self) -> None:
self.run_command('build_ext')
return super().run()

class CMakeExtension(Extension):

def __init__(self, name, path, sourcedir=""):
Extension.__init__(self, name, sources=[])
self.sourcedir = os.path.abspath(sourcedir)
@@ -201,7 +211,8 @@ class CMakeBuild(build_ext):
try:
out = subprocess.check_output(["cmake", "--version"])
except OSError:
raise RuntimeError("CMake must be installed to build the following extensions: " + ", ".join(e.name for e in self.extensions))
raise RuntimeError("CMake must be installed to build the following extensions: " +
", ".join(e.name for e in self.extensions))

match = re.search(r"version\s*(?P<major>\d+)\.(?P<minor>\d+)([\d.]+)?", out.decode())
cmake_major, cmake_minor = int(match.group("major")), int(match.group("minor"))
@@ -228,8 +239,10 @@ class CMakeBuild(build_ext):
# python directories
python_include_dir = sysconfig.get_path("platinclude")
cmake_args = [
"-G", "Ninja", # Ninja is much faster than make
"-DCMAKE_MAKE_PROGRAM=" + ninja_dir, # Pass explicit path to ninja otherwise cmake may cache a temporary path
"-G",
"Ninja", # Ninja is much faster than make
"-DCMAKE_MAKE_PROGRAM=" +
ninja_dir, # Pass explicit path to ninja otherwise cmake may cache a temporary path
"-DCMAKE_EXPORT_COMPILE_COMMANDS=ON",
"-DLLVM_ENABLE_WERROR=ON",
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
@@ -263,12 +276,28 @@ class CMakeBuild(build_ext):
build_args += ['-j' + max_jobs]

if check_env_flag("TRITON_BUILD_WITH_CLANG_LLD"):
cmake_args += ["-DCMAKE_C_COMPILER=clang",
"-DCMAKE_CXX_COMPILER=clang++",
"-DCMAKE_LINKER=lld",
"-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld",
"-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld",
"-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld"]
cmake_args += [
"-DCMAKE_C_COMPILER=clang",
"-DCMAKE_CXX_COMPILER=clang++",
"-DCMAKE_LINKER=lld",
"-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld",
"-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld",
"-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld",
]

# Note that asan doesn't work with binaries that use the GPU, so this is
# only useful for tools like triton-opt that don't run code on the GPU.
#
# I tried and gave up getting msan to work. It seems that libstdc++'s
# std::string does not play nicely with clang's msan (I didn't try
# gcc's). I was unable to configure clang to ignore the error, and I
# also wasn't able to get libc++ to work, but that doesn't mean it's
# impossible. :)
if check_env_flag("TRITON_BUILD_WITH_ASAN"):
cmake_args += [
"-DCMAKE_C_FLAGS=-fsanitize=address",
"-DCMAKE_CXX_FLAGS=-fsanitize=address",
]

if check_env_flag("TRITON_BUILD_WITH_CCACHE"):
cmake_args += [
@@ -282,9 +311,27 @@ class CMakeBuild(build_ext):
subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir)

download_and_copy(src_path='bin/ptxas', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2")
download_and_copy(src_path='bin/cuobjdump', version='12.1.111', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2")
download_and_copy(src_path='bin/nvdisasm', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2")
download_and_copy(
src_path="bin/ptxas",
variable="TRITON_PTXAS_PATH",
version="12.1.105",
url_func=lambda arch, version:
f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2",
)
download_and_copy(
src_path="bin/cuobjdump",
variable="TRITON_CUOBJDUMP_PATH",
version="12.1.111",
url_func=lambda arch, version:
f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2",
)
download_and_copy(
src_path="bin/nvdisasm",
variable="TRITON_NVDISASM_PATH",
version="12.1.105",
url_func=lambda arch, version:
f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2",
)
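
The variable argument threaded through the calls above lets a build reuse an existing binary: download_and_copy now returns immediately when that environment variable is already set, so nothing is fetched from conda. A minimal sketch of opting out of the bundled ptxas (the path below is illustrative only, not part of this diff):

import os

# Hypothetical override: with TRITON_PTXAS_PATH already set, the
# download_and_copy(src_path="bin/ptxas", variable="TRITON_PTXAS_PATH", ...)
# call above returns early and the conda tarball is never downloaded.
os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda/bin/ptxas"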

setup(
name="triton",
@@ -307,10 +354,14 @@ setup(
"triton/third_party",
"triton/tools",
],
<<<<<<< HEAD
long_description_content_type="text/markdown",
install_requires=[
"filelock"
],
=======
install_requires=["filelock"],
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
include_package_data=True,
ext_modules=[CMakeExtension("triton", "triton/_C/")],
cmdclass={"build_ext": CMakeBuild, "build_py": CMakeBuildPy, "clean": CMakeClean},

@@ -229,6 +229,12 @@ void init_triton_ir(py::module &&m) {
|
||||
.value("RELAXED", mlir::triton::MemSemantic::RELAXED)
|
||||
.export_values();
|
||||
|
||||
py::enum_<mlir::triton::MemSyncScope>(m, "MEM_SYNC_SCOPE", py::module_local())
|
||||
.value("GPU", mlir::triton::MemSyncScope::GPU)
|
||||
.value("CTA", mlir::triton::MemSyncScope::CTA)
|
||||
.value("SYSTEM", mlir::triton::MemSyncScope::SYSTEM)
|
||||
.export_values();
|
||||
|
||||
py::enum_<mlir::triton::EvictionPolicy>(m, "EVICTION_POLICY",
|
||||
py::module_local())
|
||||
.value("NORMAL", mlir::triton::EvictionPolicy::NORMAL)
|
||||
@@ -1527,7 +1533,8 @@ void init_triton_ir(py::module &&m) {
|
||||
// // atomic
|
||||
.def("create_atomic_cas",
|
||||
[](TritonOpBuilder &self, mlir::Value &ptr, mlir::Value &cmp,
|
||||
mlir::Value &val, mlir::triton::MemSemantic sem) -> mlir::Value {
|
||||
mlir::Value &val, mlir::triton::MemSemantic sem,
|
||||
mlir::triton::MemSyncScope scope) -> mlir::Value {
|
||||
mlir::Type dstType;
|
||||
if (auto srcTensorType =
|
||||
ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
|
||||
@@ -1542,12 +1549,13 @@ void init_triton_ir(py::module &&m) {
|
||||
dstType = ptrType.getPointeeType();
|
||||
}
|
||||
return self.create<mlir::triton::AtomicCASOp>(dstType, ptr, cmp,
|
||||
val, sem);
|
||||
val, sem, scope);
|
||||
})
|
||||
.def("create_atomic_rmw",
|
||||
[](TritonOpBuilder &self, mlir::triton::RMWOp rmwOp,
|
||||
mlir::Value &ptr, mlir::Value &val, mlir::Value &mask,
|
||||
mlir::triton::MemSemantic sem) -> mlir::Value {
|
||||
mlir::triton::MemSemantic sem,
|
||||
mlir::triton::MemSyncScope scope) -> mlir::Value {
|
||||
mlir::Type dstType;
|
||||
if (auto srcTensorType =
|
||||
ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
|
||||
@@ -1561,8 +1569,8 @@ void init_triton_ir(py::module &&m) {
|
||||
.cast<mlir::triton::PointerType>();
|
||||
dstType = ptrType.getPointeeType();
|
||||
}
|
||||
return self.create<mlir::triton::AtomicRMWOp>(dstType, rmwOp, ptr,
|
||||
val, mask, sem);
|
||||
return self.create<mlir::triton::AtomicRMWOp>(
|
||||
dstType, rmwOp, ptr, val, mask, sem, scope);
|
||||
})
|
||||
// External
|
||||
.def("create_extern_elementwise",
|
||||
@@ -1764,6 +1772,10 @@ void init_triton_ir(py::module &&m) {
|
||||
[](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createTritonGPUCoalescePass());
|
||||
})
|
||||
.def("add_tritongpu_optimize_thread_locality_pass",
|
||||
[](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createTritonGPUOptimizeThreadLocalityPass());
|
||||
})
|
||||
.def("add_symbol_dce_pass",
|
||||
[](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createSymbolDCEPass());
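
The bindings above now pass a MemSyncScope, alongside the MemSemantic, into create_atomic_cas and create_atomic_rmw. A hedged sketch of how this might surface in the Python-level API, assuming the language atomics expose matching sem/scope keyword arguments (the keyword names are an assumption, not something this diff confirms):

import triton
import triton.language as tl


@triton.jit
def bump_counter(counter_ptr):
    # Assumed API: sem selects the memory ordering (e.g. "acq_rel", "relaxed"),
    # scope selects the synchronization scope ("gpu", "cta", or "sys").
    tl.atomic_add(counter_ptr, 1, sem="acq_rel", scope="gpu")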
@@ -13,12 +13,11 @@ import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.common.backend import BaseBackend, register_backend
|
||||
from triton.common.backend import (BaseBackend, compute_core_version_key, register_backend)
|
||||
from triton.common.build import quiet
|
||||
from triton.compiler.make_launcher import make_so_cache_key
|
||||
from triton.runtime.cache import get_cache_manager
|
||||
from triton.runtime.driver import DriverBase
|
||||
from triton.runtime.jit import version_key
|
||||
|
||||
|
||||
def build_for_backend(name, src, srcdir):
|
||||
@@ -81,6 +80,7 @@ def build_for_backend(name, src, srcdir):
|
||||
|
||||
|
||||
class ExtensionUtils:
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
cls.instance = super(ExtensionUtils, cls).__new__(cls)
|
||||
@@ -110,6 +110,7 @@ class ExtensionUtils:
|
||||
|
||||
|
||||
class ExtensionDriver(DriverBase):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
cls.instance = super(ExtensionDriver, cls).__new__(cls)
|
||||
@@ -125,6 +126,7 @@ class ExtensionBackend(BaseBackend):
|
||||
def __init__(self, device_type: str) -> None:
|
||||
super(ExtensionBackend, self).__init__(device_type)
|
||||
self.driver = ExtensionDriver()
|
||||
self.version_key = None
|
||||
|
||||
def add_stages(self, arch, extern_libs, stages):
|
||||
filter_in_stages = ["ast", "ttir", "ttgir"]
|
||||
@@ -163,9 +165,14 @@ class ExtensionBackend(BaseBackend):
|
||||
def get_architecture_descriptor(self, **kwargs):
|
||||
return ""
|
||||
|
||||
def get_version_key(self):
|
||||
if self.version_key is None:
|
||||
self.version_key = compute_core_version_key()
|
||||
return self.version_key
|
||||
|
||||
def make_launcher_stub(self, name, signature, constants):
|
||||
# name of files that are cached
|
||||
so_cache_key = make_so_cache_key(version_key(), signature, constants)
|
||||
so_cache_key = make_so_cache_key(self.get_version_key(), signature, constants)
|
||||
so_cache_manager = get_cache_manager(so_cache_key)
|
||||
so_name = f"{name}.so"
|
||||
# retrieve stub from cache if it exists
|
||||
@@ -250,13 +257,13 @@ def test_dummy_backend():
|
||||
|
||||
inp = torch.randn(10)
|
||||
out = torch.randn(10)
|
||||
kernel[(10,)](inp, out, 10, XBLOCK=16)
|
||||
kernel[(10, )](inp, out, 10, XBLOCK=16)
|
||||
spec = importlib.util.spec_from_file_location("__triton_launcher", ExtensionBackend.stub_so_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
launch_counter = getattr(mod, "launch_counter")
|
||||
|
||||
for _ in range(100):
|
||||
kernel[(10,)](inp, out, 10, XBLOCK=16)
|
||||
kernel[(10, )](inp, out, 10, XBLOCK=16)
|
||||
|
||||
assert launch_counter() > 0
|
||||
|
||||
@@ -4,9 +4,7 @@ import pytest
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--backend", action="store", default="", help="Codegen backend"
|
||||
)
|
||||
parser.addoption("--backend", action="store", default="", help="Codegen backend")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -24,10 +24,10 @@ def test_xpu_backend(cmdopt):
|
||||
|
||||
if has_ipex:
|
||||
for _ in range(1000):
|
||||
x = torch.randn((65536,), device="xpu", dtype=torch.float32)
|
||||
y = torch.randn((65536,), device="xpu", dtype=torch.float32)
|
||||
z = torch.zeros((65536,), device="xpu", dtype=torch.float32)
|
||||
kernel[(65536,)](x, y, z, num_warps=32)
|
||||
x = torch.randn((65536, ), device="xpu", dtype=torch.float32)
|
||||
y = torch.randn((65536, ), device="xpu", dtype=torch.float32)
|
||||
z = torch.zeros((65536, ), device="xpu", dtype=torch.float32)
|
||||
kernel[(65536, )](x, y, z, num_warps=32)
|
||||
assert torch.all(x + y == z)
|
||||
else:
|
||||
return

python/test/regression/test_cast_matmul.py (new file, 100 lines)
@@ -0,0 +1,100 @@
"""
issue: https://github.com/openai/triton/issues/2523

Fused type conversion and matmul, based on the Triton matmul. The differences from the plain matmul are:
1. C's dtype (dot_out_dtype) is forced to one of ["float16", "float32"]
2. A and B are accepted with dtype in ["float32", "float64"]
"""
import pytest
import torch

import triton.language as tl
from triton import cdiv, jit

input_dtypes = ["float32", "float64"]
out_dtypes = ["float16", "float32"]


@pytest.mark.parametrize("M, K, N, w_dtype, x_dtype, out_dtype",
[(M, K, N, w, x, o) #
for (M, K, N) in [(128, 128, 128), (1280, 768, 1024)] #
for w in input_dtypes
for x in input_dtypes #
for o in out_dtypes])
def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype):
if x_dtype == w_dtype:
pytest.skip("skip same dtype")
device = torch.cuda.current_device()
x_dtype = getattr(torch, x_dtype)
w_dtype = getattr(torch, w_dtype)
a = torch.randn((M, K), device=device, dtype=x_dtype)
b = torch.randn((K, N), device=device, dtype=w_dtype)
torch_dtype = getattr(torch, out_dtype)
triton_dtype = getattr(tl, out_dtype) # <- here force dot_out_dtype
out_torch = torch.matmul(a.to(torch_dtype), b.to(torch_dtype))
out_triton = torch.empty((M, N), device=device, dtype=torch_dtype)

allow_tf32 = True
# launch kernel
BLOCK_M, BLOCK_N, BLOCK_K = 16, 16, 32
grid = ((cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)), 1)

@jit
def matmul_kernel(A, B, C, M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
dot_out_dtype: tl.constexpr, #
allow_tf32: tl.constexpr, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, #
BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr):
# matrix multiplication
pid = tl.program_id(0)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = tl.arange(0, BLOCK_K)
# pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
for k in range(0, tl.cdiv(K, BLOCK_K)):
k_remaining = K - k * BLOCK_K
_0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
a = a.to(C.dtype.element_ty)
b = b.to(C.dtype.element_ty)
acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
acc = acc.to(C.dtype.element_ty)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
tl.store(C, acc, mask=mask)

matmul_kernel[grid](
a, b, out_triton, M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
out_triton.stride(0), out_triton.stride(1), dot_out_dtype=triton_dtype, #
allow_tf32=allow_tf32, #
GROUP_M=8, #
BLOCK_M=BLOCK_M, #
BLOCK_N=BLOCK_N, #
BLOCK_K=BLOCK_K)

torch.testing.assert_close(out_torch, out_triton, atol=0.3, rtol=0.01)

@@ -14,18 +14,14 @@ def test_chained_matmul():
|
||||
return torch.einsum('MN,NK->MK', intermediate, c)
|
||||
|
||||
@triton.jit
|
||||
def chained_matmul_kernel(
|
||||
A, # shape: (m, k)
|
||||
B, # shape: (n, k)
|
||||
C, # shape: (n, k)
|
||||
out, # shape: (m, k)
|
||||
m, n, k: tl.constexpr,
|
||||
block_m: tl.constexpr,
|
||||
block_n: tl.constexpr,
|
||||
block_k: tl.constexpr):
|
||||
def chained_matmul_kernel(A, # shape: (m, k)
|
||||
B, # shape: (n, k)
|
||||
C, # shape: (n, k)
|
||||
out, # shape: (m, k)
|
||||
m, n, k: tl.constexpr, #
|
||||
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr):
|
||||
|
||||
tl.static_assert(block_k == k,
|
||||
f"expected block_k == k but got {block_k} != {k}")
|
||||
tl.static_assert(block_k == k, f"expected block_k == k but got {block_k} != {k}")
|
||||
|
||||
block_ix = tl.program_id(0)
|
||||
a_tile = (block_ix * block_m + tl.arange(0, block_m))[:, None] * block_k \
|
||||
@@ -55,35 +51,33 @@ def test_chained_matmul():
|
||||
m, n, k = 32, 64, 128
|
||||
block_m, block_n, block_k = 16, 32, k
|
||||
|
||||
grid = (triton.cdiv(m, block_m),)
|
||||
a = torch.randint(low=0, high=2, size=(m, k), dtype=torch.float16,
|
||||
device='cuda')
|
||||
b = torch.randint(low=0, high=2, size=(n, k), dtype=torch.float16,
|
||||
device='cuda')
|
||||
grid = (triton.cdiv(m, block_m), )
|
||||
a = torch.randint(low=0, high=2, size=(m, k), dtype=torch.float16, device='cuda')
|
||||
b = torch.randint(low=0, high=2, size=(n, k), dtype=torch.float16, device='cuda')
|
||||
c = torch.randint_like(b, low=0, high=2)
|
||||
triton_result = torch.zeros_like(a)
|
||||
|
||||
torch_result = chained_matmul_reference(a, b, c)
|
||||
chained_matmul_kernel[grid](a, b, c, triton_result, m, n, k,
|
||||
block_m=block_m, block_n=block_n,
|
||||
block_k=block_k)
|
||||
chained_matmul_kernel[grid](
|
||||
a, b, c, triton_result, m, n, k, #
|
||||
block_m=block_m, block_n=block_n, block_k=block_k)
|
||||
|
||||
assert (torch_result == triton_result).all()
|
||||
|
||||
|
||||
def test_vecmat():
|
||||
|
||||
@triton.jit
|
||||
def batched_vecmat(
|
||||
# inputs
|
||||
A, # shape: [dim_m, dim_k]
|
||||
B, # shape: [dim_m, dim_n, dim_k]
|
||||
# dimensions
|
||||
# inputs
|
||||
A, # shape: [dim_m, dim_k]
|
||||
B, # shape: [dim_m, dim_n, dim_k]
|
||||
# dimensions
|
||||
dim_m, dim_n, dim_k,
|
||||
# outputs
|
||||
output,
|
||||
# block information
|
||||
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr
|
||||
):
|
||||
# outputs
|
||||
output,
|
||||
# block information
|
||||
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr):
|
||||
m_index = tl.program_id(0)
|
||||
n_index = tl.program_id(1)
|
||||
# Output tile
|
||||
@@ -125,9 +119,10 @@ def test_vecmat():
|
||||
|
||||
grid = (M // block_m, N // block_n)
|
||||
|
||||
batched_vecmat[grid](A_tri, B_tri, M, N, K, C_tri,
|
||||
block_m=block_m, block_n=block_n, block_k=block_k,
|
||||
num_warps=4, num_stages=1)
|
||||
batched_vecmat[grid](
|
||||
A_tri, B_tri, M, N, K, C_tri, #
|
||||
block_m=block_m, block_n=block_n, block_k=block_k, #
|
||||
num_warps=4, num_stages=1)
|
||||
|
||||
A_expanded = A[:, np.newaxis, :]
|
||||
A_broadcasted = np.broadcast_to(A_expanded, (M, N, K))
|
||||
@@ -137,18 +132,18 @@ def test_vecmat():
|
||||
np.testing.assert_allclose(C_ref, C_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("type", ["pre_load", "post_load", "post_pre_mixed", "post_load_two_iters", "post_load_three_iters"])
|
||||
@pytest.mark.parametrize("type",
|
||||
["pre_load", "post_load", "post_pre_mixed", "post_load_two_iters", "post_load_three_iters"])
|
||||
def test_iv_dependent_matmul(type):
|
||||
|
||||
@triton.jit
|
||||
def kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
type: tl.constexpr
|
||||
):
|
||||
def kernel(a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
|
||||
type: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
pid_m = pid // num_pid_n
|
||||
@@ -216,15 +211,16 @@ def test_iv_dependent_matmul(type):
|
||||
b = torch.rand((K, N), device='cuda')
|
||||
|
||||
torch_output = torch.mm(a, b)
|
||||
triton_output = torch.empty_like(
|
||||
torch_output, device=torch_output.device)
|
||||
triton_output = torch.empty_like(torch_output, device=torch_output.device)
|
||||
|
||||
def grid(META):
|
||||
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
|
||||
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
|
||||
|
||||
num_stages = 4 if type == "post_load_three_iters" else 3
|
||||
kernel[grid](a, b, triton_output, M, N, K, a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1), triton_output.stride(0), triton_output.stride(1),
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
type=type, num_stages=num_stages)
|
||||
kernel[grid](
|
||||
a, b, triton_output, M, N, K, #
|
||||
a.stride(0), a.stride(1), b.stride(0), b.stride(1), #
|
||||
triton_output.stride(0), triton_output.stride(1), #
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, type=type, #
|
||||
num_stages=num_stages)
|
||||
torch.testing.assert_close(torch_output, triton_output, rtol=1e-2, atol=1e-2)
|
||||
|
||||
@@ -26,7 +26,6 @@ sm_clocks = {'v100': 1350, 'a100': 1350}
|
||||
mem_clocks = {'v100': 877, 'a100': 1215}
|
||||
|
||||
matmul_data = {
|
||||
# NOTE:
|
||||
'a100': {
|
||||
# square
|
||||
(512, 512, 512): {'float16': 0.108, 'float32': 0.097, 'int8': 0.05},
|
||||
@@ -49,10 +48,9 @@ matmul_data = {
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M, N, K, dtype_str',
|
||||
[(M, N, K, dtype_str)
|
||||
for M, N, K in matmul_data[DEVICE_NAME].keys()
|
||||
for dtype_str in ['float16']])
|
||||
@pytest.mark.parametrize('M, N, K, dtype_str', [(M, N, K, dtype_str)
|
||||
for M, N, K in matmul_data[DEVICE_NAME].keys()
|
||||
for dtype_str in ['float16']])
|
||||
def test_matmul(M, N, K, dtype_str):
|
||||
stream = torch.cuda.Stream()
|
||||
torch.cuda.set_stream(stream)
|
||||
@@ -86,8 +84,7 @@ def test_matmul(M, N, K, dtype_str):
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _add(x_ptr, y_ptr, output_ptr, n_elements,
|
||||
BLOCK_SIZE: tl.constexpr):
|
||||
def _add(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
@@ -136,36 +133,36 @@ def test_elementwise(N, dtype_str):
|
||||
print_perf(ms, cur_gpu_util, ref_gpu_util)
|
||||
triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01)
|
||||
|
||||
|
||||
#######################
|
||||
# Flash-Attention
|
||||
#######################
|
||||
|
||||
|
||||
flash_attention_data = {
|
||||
"a100": {
|
||||
(4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.542,
|
||||
(4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.471,
|
||||
(4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.155,
|
||||
(4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.203,
|
||||
(4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.202,
|
||||
(4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.108,
|
||||
(4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.232,
|
||||
(4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.231,
|
||||
(4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.138,
|
||||
(4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.306,
|
||||
(4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.266,
|
||||
(4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.098,
|
||||
(4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.134,
|
||||
(4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.135,
|
||||
(4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.066,
|
||||
(4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.092,
|
||||
(4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.541,
|
||||
(4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.471,
|
||||
(4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.150,
|
||||
(4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.263,
|
||||
(4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.291,
|
||||
(4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.255,
|
||||
(4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.144,
|
||||
(4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.306,
|
||||
(4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.266,
|
||||
(4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.098,
|
||||
(4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.159,
|
||||
(4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.136,
|
||||
(4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.159,
|
||||
(4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.088,
|
||||
}
|
||||
}
|
||||
@@ -221,8 +218,7 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str):
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _sum(x_ptr, y_ptr, output_ptr, n_elements,
|
||||
BLOCK_SIZE: tl.constexpr):
|
||||
def _sum(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
@@ -260,8 +256,8 @@ def test_reductions(N, dtype_str):
|
||||
y = torch.randn_like(z)
|
||||
else:
|
||||
info = torch.iinfo(dtype)
|
||||
x = torch.randint(info.min, info.max, (N,), dtype=dtype, device='cuda')
|
||||
y = torch.randint(info.min, info.max, (N,), dtype=dtype, device='cuda')
|
||||
x = torch.randint(info.min, info.max, (N, ), dtype=dtype, device='cuda')
|
||||
y = torch.randint(info.min, info.max, (N, ), dtype=dtype, device='cuda')
|
||||
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
|
||||
fn = lambda: _sum[grid](x, y, z, N, BLOCK_SIZE=1024)
|
||||
ms = triton.testing.do_bench_cudagraph(fn)
|
||||
|
||||
@@ -9,6 +9,7 @@ import yaml
|
||||
|
||||
|
||||
class ComparisonResult:
|
||||
|
||||
def __init__(self, name: str, numComparisons: int, diffs: List[str] = None, errors: List[str] = None):
|
||||
self.name = name
|
||||
self.numComparisons = numComparisons
|
||||
@@ -142,7 +143,8 @@ def doFilesMatch(path1: str, path2: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def compareMatchingFiles(name: str, nameToHashes1: Dict[str, List[str]], nameToHashes2: Dict[str, List[str]], args) -> ComparisonResult:
|
||||
def compareMatchingFiles(name: str, nameToHashes1: Dict[str, List[str]], nameToHashes2: Dict[str, List[str]],
|
||||
args) -> ComparisonResult:
|
||||
"""
|
||||
Compare files with the given name in all hashes in both paths
|
||||
Return the first mismatching files as a tuple (file1, file2), otherwise, return an empty tuple
|
||||
|
||||
@@ -18,7 +18,6 @@
|
||||
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
@@ -35,18 +34,15 @@ import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q, K, V, sm_scale,
|
||||
L, M,
|
||||
Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
stride_oz, stride_oh, stride_om, stride_on,
|
||||
Z, H, N_CTX, D0,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
def _fwd_kernel(Q, K, V, sm_scale, #
|
||||
L, M, #
|
||||
Out, #
|
||||
stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vk, stride_vn, #
|
||||
stride_oz, stride_oh, stride_om, stride_on, #
|
||||
Z, H, N_CTX, D0, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr):
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
|
||||
@@ -61,31 +57,38 @@ def _fwd_kernel(
|
||||
|
||||
stride_qh_2d = stride_qh // stride_qm // stride_qk
|
||||
|
||||
q_tile_ptr = tl.make_block_ptr(base=Q,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(
|
||||
off_hz * stride_qh_2d + start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
k_tile_ptr = tl.make_block_ptr(base=K,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_kn, stride_kk),
|
||||
offsets=(off_hz * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
v_tile_ptr = tl.make_block_ptr(base=V,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_vk, stride_vn),
|
||||
offsets=(off_hz * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
out_tile_ptr = tl.make_block_ptr(base=Out,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
q_tile_ptr = tl.make_block_ptr(
|
||||
base=Q,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
k_tile_ptr = tl.make_block_ptr(
|
||||
base=K,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_kn, stride_kk),
|
||||
offsets=(off_hz * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
v_tile_ptr = tl.make_block_ptr(
|
||||
base=V,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_vk, stride_vn),
|
||||
offsets=(off_hz * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
out_tile_ptr = tl.make_block_ptr(
|
||||
base=Out,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
# load q: it will stay in SRAM throughout
|
||||
q = tl.load(q_tile_ptr)
|
||||
|
||||
@@ -96,8 +99,7 @@ def _fwd_kernel(
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (
start_n + offs_n[None, :]), qk, float("-inf"))
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
# compute new m
m_curr = tl.maximum(tl.max(qk, 1), m_prev)
# correct old l
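
For reference, the running (online) softmax update that these lines maintain, in the kernel's own notation (m is the running row maximum, l the running normalizer; this is the standard streaming-softmax identity, not something introduced by this diff):

$m_{\mathrm{curr}} = \max\bigl(m_{\mathrm{prev}},\ \max_j qk_{ij}\bigr)$
$l_{\mathrm{curr}} = l_{\mathrm{prev}}\, e^{\,m_{\mathrm{prev}} - m_{\mathrm{curr}}} + \sum_j e^{\,qk_{ij} - m_{\mathrm{curr}}}$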
@@ -133,11 +135,9 @@ def _fwd_kernel(
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bwd_preprocess(
|
||||
Out, DO, L,
|
||||
NewDO, Delta,
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
|
||||
):
|
||||
def _bwd_preprocess(Out, DO, L, #
|
||||
NewDO, Delta, #
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr):
|
||||
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, D_HEAD)
|
||||
# load
|
||||
@@ -153,19 +153,14 @@ def _bwd_preprocess(
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bwd_kernel(
|
||||
Q, K, V, sm_scale, Out, DO,
|
||||
DQ, DK, DV,
|
||||
L, M,
|
||||
D,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
Z, H, N_CTX, D0,
|
||||
num_block,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
def _bwd_kernel(Q, K, V, sm_scale, Out, DO, #
|
||||
DQ, DK, DV, #
|
||||
L, M, #
|
||||
D, stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vk, stride_vn, #
|
||||
Z, H, N_CTX, D0, #
|
||||
num_block, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr):
|
||||
off_hz = tl.program_id(0)
|
||||
off_z = off_hz // H
|
||||
off_h = off_hz % H
|
||||
@@ -173,55 +168,62 @@ def _bwd_kernel(
|
||||
stride_qz_2d = stride_qz // stride_qm // stride_qk
|
||||
stride_qh_2d = stride_qh // stride_qm // stride_qk
|
||||
|
||||
q_tile_ptr = tl.make_block_ptr(base=Q,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(
|
||||
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
k_tile_ptr = tl.make_block_ptr(base=K,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_kn, stride_kk),
|
||||
offsets=(
|
||||
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
v_tile_ptr = tl.make_block_ptr(base=V,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_vk, stride_vn),
|
||||
offsets=(
|
||||
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
do_tile_ptr = tl.make_block_ptr(base=DO,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(
|
||||
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
dq_tile_ptr = tl.make_block_ptr(base=DQ,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(
|
||||
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
dk_tile_ptr = tl.make_block_ptr(base=DK,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(
|
||||
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
dv_tile_ptr = tl.make_block_ptr(base=DV,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(
|
||||
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
q_tile_ptr = tl.make_block_ptr(
|
||||
base=Q,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
k_tile_ptr = tl.make_block_ptr(
|
||||
base=K,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_kn, stride_kk),
|
||||
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
v_tile_ptr = tl.make_block_ptr(
|
||||
base=V,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_vk, stride_vn),
|
||||
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
do_tile_ptr = tl.make_block_ptr(
|
||||
base=DO,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
dq_tile_ptr = tl.make_block_ptr(
|
||||
base=DQ,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
dk_tile_ptr = tl.make_block_ptr(
|
||||
base=DK,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
dv_tile_ptr = tl.make_block_ptr(
|
||||
base=DV,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
# offset pointers for batch/head
|
||||
DQ += off_z * stride_qz + off_h * stride_qh
|
||||
for start_n in range(0, num_block):
|
||||
@@ -250,8 +252,7 @@ def _bwd_kernel(
|
||||
# recompute p = softmax(qk, dim=-1).T
|
||||
# NOTE: `do` is pre-divided by `l`; no normalization here
|
||||
qk = tl.dot(q, tl.trans(k))
|
||||
qk = tl.where(offs_m_curr[:, None] >= (
|
||||
offs_n[None, :]), qk, float("-inf"))
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
|
||||
m = tl.load(m_ptrs + offs_m_curr)
|
||||
p = tl.exp(qk * sm_scale - m[:, None])
|
||||
# compute dv
|
||||
@@ -301,29 +302,21 @@ class _attention(torch.autograd.Function):
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
o = torch.empty_like(q)
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
|
||||
L = torch.empty(
|
||||
(q.shape[0] * q.shape[1], q.shape[2]),
|
||||
device=q.device,
|
||||
dtype=torch.float32)
|
||||
m = torch.empty(
|
||||
(q.shape[0] * q.shape[1], q.shape[2]),
|
||||
device=q.device,
|
||||
dtype=torch.float32)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
D0 = q.shape[0] * q.shape[1] * q.shape[2]
|
||||
_fwd_kernel[grid](
|
||||
q, k, v, sm_scale,
|
||||
L, m,
|
||||
o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2], D0,
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=Lk, num_warps=num_warps,
|
||||
num_stages=2,
|
||||
)
|
||||
q, k, v, sm_scale, #
|
||||
L, m, #
|
||||
o, #
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
|
||||
q.shape[0], q.shape[1], q.shape[2], D0, #
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=Lk, #
|
||||
num_warps=num_warps, num_stages=2)
|
||||
|
||||
ctx.save_for_backward(q, k, v, o, L, m)
|
||||
ctx.grid = grid
|
||||
@@ -343,25 +336,22 @@ class _attention(torch.autograd.Function):
|
||||
delta = torch.empty_like(l)
|
||||
D0 = q.shape[0] * q.shape[1] * q.shape[2]
|
||||
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
|
||||
o, do, l,
|
||||
do_scaled, delta,
|
||||
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
_bwd_kernel[(ctx.grid[1],)](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do_scaled,
|
||||
dq, dk, dv,
|
||||
l, m,
|
||||
delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2], D0,
|
||||
ctx.grid[0],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
|
||||
num_stages=1,
|
||||
)
|
||||
o, do, l, #
|
||||
do_scaled, delta, #
|
||||
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL)
|
||||
_bwd_kernel[(ctx.grid[1], )](
|
||||
q, k, v, ctx.sm_scale, #
|
||||
o, do_scaled, #
|
||||
dq, dk, dv, #
|
||||
l, m, #
|
||||
delta, #
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
|
||||
q.shape[0], q.shape[1], q.shape[2], D0, #
|
||||
ctx.grid[0], #
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=ctx.BLOCK_DMODEL, #
|
||||
num_warps=8, num_stages=1)
|
||||
return dq, dk, dv, None
|
||||
|
||||
|
||||
@@ -380,15 +370,9 @@ attention = _attention.apply
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+")
|
||||
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
torch.manual_seed(20)
|
||||
q = torch.empty(
|
||||
(Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(
|
||||
mean=0.1, std=0.2).requires_grad_()
|
||||
k = torch.empty(
|
||||
(Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(
|
||||
mean=0.4, std=0.2).requires_grad_()
|
||||
v = torch.empty(
|
||||
(Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(
|
||||
mean=0.3, std=0.2).requires_grad_()
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
|
||||
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
|
||||
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
|
||||
sm_scale = 0.2
|
||||
dout = torch.randn_like(q)
|
||||
# reference implementation
|
||||
@@ -427,22 +411,25 @@ except BaseException:
|
||||
|
||||
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
|
||||
# vary seq length for fixed head and batch=4
|
||||
configs = [triton.testing.Benchmark(
|
||||
x_names=['N_CTX'],
|
||||
x_vals=[2**i for i in range(10, 14)],
|
||||
line_arg='provider',
|
||||
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
|
||||
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
|
||||
styles=[('red', '-'), ('blue', '-')],
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
||||
args={
|
||||
'H': N_HEADS,
|
||||
'BATCH': BATCH,
|
||||
'D_HEAD': D_HEAD,
|
||||
'dtype': torch.float16,
|
||||
'mode': mode}
|
||||
) for mode in ['fwd', 'bwd']]
|
||||
configs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=['N_CTX'],
|
||||
x_vals=[2**i for i in range(10, 14)],
|
||||
line_arg='provider',
|
||||
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
|
||||
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
|
||||
styles=[('red', '-'), ('blue', '-')],
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
||||
args={
|
||||
'H': N_HEADS,
|
||||
'BATCH': BATCH,
|
||||
'D_HEAD': D_HEAD,
|
||||
'dtype': torch.float16,
|
||||
'mode': mode,
|
||||
},
|
||||
) for mode in ['fwd', 'bwd']
|
||||
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
@@ -463,9 +450,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
|
||||
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
if provider == "flash":
|
||||
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
|
||||
cu_seqlens = torch.zeros(
|
||||
(BATCH + 1,), device=device, dtype=torch.int32)
|
||||
lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device)
|
||||
cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32)
|
||||
cu_seqlens[1:] = lengths.cumsum(0)
|
||||
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
|
||||
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
|
||||
|
||||
@@ -32,19 +32,30 @@ import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def matmul_no_scf_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr
|
||||
):
|
||||
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
|
||||
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
|
||||
offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
|
||||
def matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr #
|
||||
):
|
||||
a_block_ptr = tl.make_block_ptr(
|
||||
base=a_ptr,
|
||||
shape=(M, K),
|
||||
strides=(stride_am, stride_ak),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_K),
|
||||
order=(1, 0),
|
||||
)
|
||||
b_block_ptr = tl.make_block_ptr(
|
||||
base=b_ptr,
|
||||
shape=(K, N),
|
||||
strides=(stride_bk, stride_bn),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_K, BLOCK_N),
|
||||
order=(0, 1),
|
||||
)
|
||||
a = tl.load(a_block_ptr)
|
||||
b = tl.load(b_block_ptr)
|
||||
|
||||
@@ -54,8 +65,8 @@ def matmul_no_scf_kernel(
|
||||
c = c.to(tl.float16)
|
||||
|
||||
if USE_TMA_EPILOGUE:
|
||||
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
|
||||
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
|
||||
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
|
||||
tl.store(c_block_ptr, c)
|
||||
else:
|
||||
offs_m = tl.arange(0, BLOCK_M)
|
||||
@@ -64,33 +75,30 @@ def matmul_no_scf_kernel(
|
||||
tl.store(c_ptrs, c)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,ENABLE_WS',
|
||||
itertools.chain(
|
||||
*[
|
||||
[
|
||||
# numCTAs = 1, no TMA multicast:
|
||||
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
# static mask, cluster 4x1
|
||||
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
# dynamic mask, cluster 2x2
|
||||
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
# small M, N
|
||||
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
] for USE_TMA_EPILOGUE in [True, False]
|
||||
for ENABLE_WS in [False, True]
|
||||
]))
|
||||
@pytest.mark.parametrize(
|
||||
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,ENABLE_WS',
|
||||
itertools.chain(*[[
|
||||
# numCTAs = 1, no TMA multicast:
|
||||
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
# static mask, cluster 4x1
|
||||
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
# dynamic mask, cluster 2x2
|
||||
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
# small M, N
|
||||
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
] for USE_TMA_EPILOGUE in [True, False] for ENABLE_WS in [False, True]]))
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, ENABLE_WS):
|
||||
if (TRANS_A):
|
||||
@@ -107,46 +115,41 @@ def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE
|
||||
else:
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
|
||||
|
||||
matmul_no_scf_kernel[(1, 1)](a_ptr=a, b_ptr=b, c_ptr=c,
|
||||
M=M, N=N, K=K,
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1),
|
||||
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,
|
||||
num_warps=NUM_WARPS,
|
||||
num_ctas=NUM_CTAS,
|
||||
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"),
|
||||
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE,
|
||||
enable_warp_specialization=ENABLE_WS)
|
||||
matmul_no_scf_kernel[(1, 1)](
|
||||
a_ptr=a, b_ptr=b, c_ptr=c, #
|
||||
M=M, N=N, K=K, #
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1), #
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1), #
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1), #
|
||||
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, #
|
||||
num_warps=NUM_WARPS, #
|
||||
num_ctas=NUM_CTAS, #
|
||||
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), #
|
||||
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE, #
|
||||
enable_warp_specialization=ENABLE_WS)
|
||||
a_f32 = a.to(torch.float32)
|
||||
b_f32 = b.to(torch.float32)
|
||||
golden = torch.matmul(a_f32, b_f32)
|
||||
torch.set_printoptions(profile="full")
|
||||
assert_close(
|
||||
c,
|
||||
golden,
|
||||
rtol=1e-2,
|
||||
atol=1e-3,
|
||||
check_dtype=False)
|
||||
assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def matmul_kernel(
|
||||
a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_wm, stride_wn,
|
||||
stride_zm, stride_zn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
|
||||
out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr,
|
||||
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
|
||||
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
|
||||
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr,
|
||||
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr,
|
||||
W_ORDER_0: tl.constexpr, W_ORDER_1: tl.constexpr,
|
||||
Z_ORDER_0: tl.constexpr, Z_ORDER_1: tl.constexpr
|
||||
):
|
||||
def matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_wm, stride_wn, #
|
||||
stride_zm, stride_zn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, #
|
||||
out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr, #
|
||||
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, #
|
||||
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, #
|
||||
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, #
|
||||
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, #
|
||||
W_ORDER_0: tl.constexpr, W_ORDER_1: tl.constexpr, #
|
||||
Z_ORDER_0: tl.constexpr, Z_ORDER_1: tl.constexpr #
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_N)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_M)
|
||||
@@ -159,13 +162,31 @@ def matmul_kernel(
|
||||
block_offset_m = pid_m * BLOCK_M
|
||||
block_offset_n = pid_n * BLOCK_N
|
||||
|
||||
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
|
||||
offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(A_ORDER_0, A_ORDER_1))
|
||||
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
|
||||
offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(B_ORDER_0, B_ORDER_1))
|
||||
a_tile_ptr = tl.make_block_ptr(
|
||||
base=a_ptr,
|
||||
shape=(M, K),
|
||||
strides=(stride_am, stride_ak),
|
||||
offsets=(block_offset_m, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_K),
|
||||
order=(A_ORDER_0, A_ORDER_1),
|
||||
)
|
||||
b_tile_ptr = tl.make_block_ptr(
|
||||
base=b_ptr,
|
||||
shape=(K, N),
|
||||
strides=(stride_bk, stride_bn),
|
||||
offsets=(0, block_offset_n),
|
||||
block_shape=(BLOCK_K, BLOCK_N),
|
||||
order=(B_ORDER_0, B_ORDER_1),
|
||||
)
|
||||
# for chain-dot, BLOCK_N must always be equal to N, and each program loads the whole W matrix
|
||||
w_tile_ptr = tl.make_block_ptr(base=w_ptr, shape=(N, N), strides=(stride_wm, stride_wn),
|
||||
offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_N), order=(W_ORDER_0, W_ORDER_1))
|
||||
w_tile_ptr = tl.make_block_ptr(
|
||||
base=w_ptr,
|
||||
shape=(N, N),
|
||||
strides=(stride_wm, stride_wn),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_N),
|
||||
order=(W_ORDER_0, W_ORDER_1),
|
||||
)
|
||||
z = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
|
||||
offs_m = block_offset_m + tl.arange(0, BLOCK_M)
|
||||
@@ -204,141 +225,151 @@ def matmul_kernel(
|
||||
|
||||
if USE_TMA_STORE:
|
||||
z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N), strides=(stride_zm, stride_zn),
|
||||
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(Z_ORDER_0, Z_ORDER_1))
|
||||
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N),
|
||||
order=(Z_ORDER_0, Z_ORDER_1))
|
||||
tl.store(z_block_ptr, z, boundary_check=(0, 1))
|
||||
else:
|
||||
tl.store(z_ptrs, z, mask=mask)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
|
||||
[
|
||||
# corner shapes
|
||||
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws)
|
||||
for shape_w_c in [
|
||||
[4096, 1, 1024, False, False, True],
|
||||
[2048, 204, 1000, True, False, True],
|
||||
[4096, 1, 1024, False, False, False],
|
||||
[2048, 204, 1000, True, False, False],
|
||||
]
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for enable_ws in [False, True]
|
||||
] + [
|
||||
# softmax epilogue
|
||||
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 16, 4, 1, 64, 64, 64],
|
||||
[128, 128, 64, 4, 1, None, None, None],
|
||||
[16, 16, 64, 4, 1, 16, 16, 64],
|
||||
[64, 64, 32, 8, 1, 64, 64, 64],
|
||||
[128, 128, 64, 4, 1, 128, 128, 128],
|
||||
]
|
||||
for epilogue in ['softmax']
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for trans_a in [False,]
|
||||
for trans_b in [True,]
|
||||
for trans_output in [False,]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [False, True]
|
||||
] + [
|
||||
# loop over epilogues besides of softmax
|
||||
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 16, 4, 1, 128, 128, 64],
|
||||
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
|
||||
# for chain-dot
|
||||
[128, 128, 64, 4, 1, None, None, None],
|
||||
[64, 64, 16, 4, 1, None, None, None],
|
||||
# small BLOCK_M and BLOCK_K
|
||||
[16, 16, 64, 4, 1, 128, 128, 64],
|
||||
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
|
||||
# repeat
|
||||
[64, 64, 32, 8, 1, 128, 256, 64],
|
||||
[64, 64, 16, 8, 2, 128, 128, 64],
|
||||
# irregular shape
|
||||
[128, 128, 64, 4, 1, 500, 200, 128],
|
||||
[128, 128, 64, 4, 2, 513, 193, 192],
|
||||
]
|
||||
for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot']
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for trans_a in [False,]
|
||||
for trans_b in [True,]
|
||||
for trans_output in [False,]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [False, True]
|
||||
if not (epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
|
||||
] + [
|
||||
# loop over tile shapes and transpose combinations
|
||||
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 32, 4, 1, 128, 256, 64],
|
||||
[128, 128, 16, 4, 4, 512, 256, 64],
|
||||
[128, 256, 32, 4, 8, 256, 256, 192],
|
||||
[512, 256, 32, 4, 8, 1024, 256, 192],
|
||||
# BLOCK_K >= 128
|
||||
[64, 128, 128, 4, 1, 512, 256, 256],
|
||||
[128, 128, 128, 4, 1, 256, 256, 192],
|
||||
[128, 128, 128, 4, 2, 256, 256, 192],
|
||||
# small BLOCK_M and BLOCK_K
|
||||
[16, 32, 32, 4, 1, 128, 256, 64],
|
||||
[32, 32, 16, 4, 1, 256, 256, 192],
|
||||
[16, 32, 64, 4, 4, 512, 256, 64],
|
||||
]
|
||||
for out_dtype in ['float32',]
|
||||
for use_tma_store in [False,]
|
||||
for trans_a in [False, True]
|
||||
for trans_b in [False, True]
|
||||
for trans_output in [False, True]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [False, True]
|
||||
] + [
|
||||
# loop over instr shapes & pipeline stages
|
||||
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for n in [16, 32, 64, 128, 256]
|
||||
for trans_output in [False,]
|
||||
for out_dtype in ['float32',]
|
||||
for use_tma_store in [False,]
|
||||
for num_stages in [2, 4, 5, 7]
|
||||
for enable_ws in [False, True]
|
||||
] + [
|
||||
# irregular shapes
|
||||
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[128, 128, 64, 4, 1],
|
||||
[256, 128, 64, 4, 2],
|
||||
[128, 128, 128, 4, 2],
|
||||
]
|
||||
for shape in [
|
||||
[512, 360, 1024],
|
||||
[360, 4096, 512],
|
||||
]
|
||||
for trans_output in [False,]
|
||||
for out_dtype in ['float32',]
|
||||
for use_tma_store in [False, True]
|
||||
for num_stages in [3, 4]
|
||||
for enable_ws in [False, True]
|
||||
])
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()
|
||||
[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
|
||||
@pytest.mark.parametrize(
|
||||
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
|
||||
[
|
||||
# corner shapes
|
||||
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws)
|
||||
for shape_w_c in [
|
||||
[4096, 1, 1024, False, False, True],
|
||||
[2048, 204, 1000, True, False, True],
|
||||
[4096, 1, 1024, False, False, False],
|
||||
[2048, 204, 1000, True, False, False],
|
||||
]
|
||||
for out_dtype in ['float16', 'float32'] #
|
||||
for use_tma_store in [False, True] #
|
||||
for enable_ws in [False, True]
|
||||
] + [
|
||||
# softmax epilogue
|
||||
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 16, 4, 1, 64, 64, 64],
|
||||
[128, 128, 64, 4, 1, None, None, None],
|
||||
[16, 16, 64, 4, 1, 16, 16, 64],
|
||||
[64, 64, 32, 8, 1, 64, 64, 64],
|
||||
[128, 128, 64, 4, 1, 128, 128, 128],
|
||||
]
|
||||
for epilogue in ['softmax']
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for trans_a in [False]
|
||||
for trans_b in [True]
|
||||
for trans_output in [False]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [False, True]
|
||||
] + [
|
||||
# loop over epilogues besides of softmax
|
||||
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 16, 4, 1, 128, 128, 64],
|
||||
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64]
|
||||
for num_warps in [4, 8]
|
||||
for num_ctas in [1, 2, 4]],
|
||||
# for chain-dot
|
||||
[128, 128, 64, 4, 1, None, None, None],
|
||||
[64, 64, 16, 4, 1, None, None, None],
|
||||
# small BLOCK_M and BLOCK_K
|
||||
[16, 16, 64, 4, 1, 128, 128, 64],
|
||||
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256]
|
||||
for num_warps in [4, 8]
|
||||
for num_ctas in [1, 2]],
|
||||
# repeat
|
||||
[64, 64, 32, 8, 1, 128, 256, 64],
|
||||
[64, 64, 16, 8, 2, 128, 128, 64],
|
||||
# irregular shape
|
||||
[128, 128, 64, 4, 1, 500, 200, 128],
|
||||
[128, 128, 64, 4, 2, 513, 193, 192],
|
||||
]
|
||||
for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot']
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for trans_a in [False]
|
||||
for trans_b in [True]
|
||||
for trans_output in [False]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [False, True]
|
||||
if not (epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
|
||||
] + [
|
||||
# loop over tile shapes and transpose combinations
|
||||
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 32, 4, 1, 128, 256, 64],
|
||||
[128, 128, 16, 4, 4, 512, 256, 64],
|
||||
[128, 256, 32, 4, 8, 256, 256, 192],
|
||||
[512, 256, 32, 4, 8, 1024, 256, 192],
|
||||
# BLOCK_K >= 128
|
||||
[64, 128, 128, 4, 1, 512, 256, 256],
|
||||
[128, 128, 128, 4, 1, 256, 256, 192],
|
||||
[128, 128, 128, 4, 2, 256, 256, 192],
|
||||
# small BLOCK_M and BLOCK_K
|
||||
[16, 32, 32, 4, 1, 128, 256, 64],
|
||||
[32, 32, 16, 4, 1, 256, 256, 192],
|
||||
[16, 32, 64, 4, 4, 512, 256, 64],
|
||||
]
|
||||
for out_dtype in ['float32']
|
||||
for use_tma_store in [False]
|
||||
for trans_a in [False, True]
|
||||
for trans_b in [False, True]
|
||||
for trans_output in [False, True]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [False, True]
|
||||
] + [
|
||||
# loop over instr shapes & pipeline stages
|
||||
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages,
|
||||
enable_ws)
|
||||
for n in [16, 32, 64, 128, 256]
|
||||
for trans_output in [False]
|
||||
for out_dtype in ['float32']
|
||||
for use_tma_store in [False]
|
||||
for num_stages in [2, 4, 5, 7]
|
||||
for enable_ws in [False, True]
|
||||
] + [
|
||||
# irregular shapes
|
||||
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[128, 128, 64, 4, 1],
|
||||
[256, 128, 64, 4, 2],
|
||||
[128, 128, 128, 4, 2],
|
||||
]
|
||||
for shape in [
|
||||
[512, 360, 1024],
|
||||
[360, 4096, 512],
|
||||
]
|
||||
for trans_output in [False]
|
||||
for out_dtype in ['float32']
|
||||
for use_tma_store in [False, True]
|
||||
for num_stages in [3, 4]
|
||||
for enable_ws in [False, True]
|
||||
])
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
|
||||
out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
|
||||
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
|
||||
'16-32-64-4-4-512-256-64-True-False',
|
||||
'16-32-64-4-4-512-256-64-True-True',
|
||||
'16-32-64-4-4-512-256-64-False-False',
|
||||
'16-32-64-4-4-512-256-64-False-True',
|
||||
'16-32-64-4-4-512-256-64-True-False',
|
||||
'16-32-64-4-4-512-256-64-True-True',
|
||||
'16-32-64-4-4-512-256-64-False-False',
|
||||
'16-32-64-4-4-512-256-64-False-True',
|
||||
]:
|
||||
pytest.skip('shapePerCTA[1] < 16 not supported')
|
||||
|
||||
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
|
||||
'16-32-64-4-1-256-256-256-False',
|
||||
'16-32-64-4-2-256-256-256-False',
|
||||
'16-32-64-4-2-256-256-256-True',
|
||||
'16-32-64-8-2-256-256-256-False',
|
||||
'16-32-64-8-2-256-256-256-True',
|
||||
'16-32-64-4-1-256-256-256-False',
|
||||
'16-32-64-4-2-256-256-256-False',
|
||||
'16-32-64-4-2-256-256-256-True',
|
||||
'16-32-64-8-2-256-256-256-False',
|
||||
'16-32-64-8-2-256-256-256-True',
|
||||
]:
|
||||
pytest.skip('Known legacy issue, ldmatrix can only support x4')
|
||||
enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower()
|
||||
if NUM_CTAS > 1 and enable_tma in ["on", "true", "1"]:
|
||||
pytest.skip('multi-CTA with TMA not supported in MaterializeLoadStore')
|
||||
|
||||
M = BLOCK_M if M is None else M
|
||||
N = BLOCK_N if N is None else N
|
||||
@@ -410,38 +441,38 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A,
|
||||
else:
|
||||
ref = d
|
||||
return ref
|
||||
|
||||
golden = process_epilogue(dot, bias, w, epilogue)
|
||||
|
||||
def grid(META):
|
||||
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
|
||||
pgm = matmul_kernel[grid](a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z,
|
||||
M=M, N=N, K=K,
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_wm=w.stride(0), stride_wn=w.stride(1),
|
||||
stride_zm=z.stride(0), stride_zn=z.stride(1),
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8,
|
||||
out_dtype=out_dtype,
|
||||
USE_TMA_STORE=USE_TMA_STORE,
|
||||
ADD_MATRIX=epilogue == 'add-matrix',
|
||||
ADD_ROWS=epilogue == 'add-rows',
|
||||
ADD_COLS=epilogue == 'add-cols',
|
||||
DO_SOFTMAX=epilogue == 'softmax',
|
||||
CHAIN_DOT=epilogue == 'chain-dot',
|
||||
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1],
|
||||
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1],
|
||||
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1],
|
||||
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1],
|
||||
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES,
|
||||
enable_warp_specialization=ENABLE_WS)
|
||||
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
|
||||
|
||||
pgm = matmul_kernel[grid](
|
||||
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
|
||||
M=M, N=N, K=K, #
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1), #
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1), #
|
||||
stride_wm=w.stride(0), stride_wn=w.stride(1), #
|
||||
stride_zm=z.stride(0), stride_zn=z.stride(1), #
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
|
||||
out_dtype=out_dtype, #
|
||||
USE_TMA_STORE=USE_TMA_STORE, #
|
||||
ADD_MATRIX=epilogue == 'add-matrix', #
|
||||
ADD_ROWS=epilogue == 'add-rows', #
|
||||
ADD_COLS=epilogue == 'add-cols', #
|
||||
DO_SOFTMAX=epilogue == 'softmax', #
|
||||
CHAIN_DOT=epilogue == 'chain-dot', #
|
||||
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
|
||||
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
|
||||
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
|
||||
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
|
||||
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, #
|
||||
enable_warp_specialization=ENABLE_WS)
|
||||
|
||||
torch.set_printoptions(profile="full")
|
||||
golden = torch.nn.functional.normalize(golden)
|
||||
z = torch.nn.functional.normalize(z)
|
||||
assert_close(z, golden,
|
||||
rtol=1e-2,
|
||||
atol=1e-3,
|
||||
check_dtype=False)
|
||||
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
|
||||
|
||||
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
|
||||
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
|
||||
|
||||
@@ -27,16 +27,20 @@ import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def gemm_fusion_kernel(A, B, C, E,
|
||||
M, N, K,
|
||||
stride_am, stride_ak, stride_bn, stride_bk, stride_cn, stride_ck, stride_em, stride_ek,
|
||||
def gemm_fusion_kernel(A, B, C, E, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, stride_bn, stride_bk, stride_cn, stride_ck, stride_em, stride_ek, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
|
||||
pid = tl.program_id(0)
|
||||
|
||||
a_tile_ptr = tl.make_block_ptr(base=A, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
b_tile_ptr = tl.make_block_ptr(base=B, shape=(N, K), strides=(stride_bn, stride_bk), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))
|
||||
c_tile_ptr = tl.make_block_ptr(base=C, shape=(N, K), strides=(stride_cn, stride_ck), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))
|
||||
e_tile_ptr = tl.make_block_ptr(base=E, shape=(M, K), strides=(stride_em, stride_ek), offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
a_tile_ptr = tl.make_block_ptr(base=A, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
b_tile_ptr = tl.make_block_ptr(base=B, shape=(N, K), strides=(stride_bn, stride_bk), offsets=(0, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))
|
||||
c_tile_ptr = tl.make_block_ptr(base=C, shape=(N, K), strides=(stride_cn, stride_ck), offsets=(0, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))
|
||||
e_tile_ptr = tl.make_block_ptr(base=E, shape=(M, K), strides=(stride_em, stride_ek), offsets=(pid * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
|
||||
acc_e = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
|
||||
a = tl.load(a_tile_ptr)
|
||||
@@ -57,66 +61,70 @@ def gemm_fusion_kernel(A, B, C, E,
|
||||
def test_gemm_fusion():
|
||||
M, N, K = 4096, 4096, 64
|
||||
BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64
|
||||
A = torch.empty(
|
||||
(M, K), dtype=torch.float16, device='cuda').normal_(
|
||||
mean=0.1, std=0.2)
|
||||
B = torch.empty(
|
||||
(N, K), dtype=torch.float16, device='cuda').normal_(
|
||||
mean=0.1, std=0.2)
|
||||
C = torch.empty(
|
||||
(N, K), dtype=torch.float16, device='cuda').normal_(
|
||||
mean=0.1, std=0.2)
|
||||
A = torch.empty((M, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)
|
||||
B = torch.empty((N, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)
|
||||
C = torch.empty((N, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)
|
||||
E = torch.empty((M, K), dtype=torch.float16, device='cuda')
|
||||
ref_out = torch.matmul(torch.matmul(A, B.T), C)
|
||||
num_warps = 4
|
||||
grid = (triton.cdiv(M, BLOCK_M), 1)
|
||||
gemm_fusion_kernel[grid](A, B, C, E, M, N, K,
|
||||
A.stride(0), A.stride(1), B.stride(0), B.stride(
|
||||
1), C.stride(0), C.stride(1), E.stride(0), E.stride(1),
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, num_warps=num_warps)
|
||||
gemm_fusion_kernel[grid](
|
||||
A, B, C, E, M, N, K, #
|
||||
A.stride(0), A.stride(1), #
|
||||
B.stride(0), B.stride(1), #
|
||||
C.stride(0), C.stride(1), #
|
||||
E.stride(0), E.stride(1), #
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, #
|
||||
num_warps=num_warps)
|
||||
|
||||
torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def batched_gemm_fusion(
|
||||
Q, K, V, Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
stride_oz, stride_oh, stride_om, stride_on,
|
||||
Z, NH, N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
def batched_gemm_fusion(Q, K, V, Out, #
|
||||
stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vk, stride_vn, #
|
||||
stride_oz, stride_oh, stride_om, stride_on, #
|
||||
Z, NH, N_CTX, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
|
||||
BLOCK_N: tl.constexpr):
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
q_tile_ptr = tl.make_block_ptr(base=Q,
|
||||
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qz, stride_qh, stride_qm, stride_qk),
|
||||
offsets=(off_hz // NH, off_hz % NH, start_m, 0),
|
||||
block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),
|
||||
order=(3, 2, 1, 0))
|
||||
k_tile_ptr = tl.make_block_ptr(base=K,
|
||||
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_kz, stride_kh, stride_kn, stride_kk),
|
||||
offsets=(off_hz // NH, off_hz % NH, 0, 0),
|
||||
block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),
|
||||
order=(3, 2, 1, 0))
|
||||
v_tile_ptr = tl.make_block_ptr(base=V,
|
||||
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_vz, stride_vh, stride_vk, stride_vn),
|
||||
offsets=(off_hz // NH, off_hz % NH, 0, 0),
|
||||
block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),
|
||||
order=(3, 2, 1, 0))
|
||||
o_tile_ptr = tl.make_block_ptr(base=Out,
|
||||
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_oz, stride_oh, stride_om, stride_on),
|
||||
offsets=(off_hz // NH, off_hz % NH, start_m, 0),
|
||||
block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),
|
||||
order=(3, 2, 1, 0))
|
||||
q_tile_ptr = tl.make_block_ptr(
|
||||
base=Q,
|
||||
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qz, stride_qh, stride_qm, stride_qk),
|
||||
offsets=(off_hz // NH, off_hz % NH, start_m, 0),
|
||||
block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),
|
||||
order=(3, 2, 1, 0),
|
||||
)
|
||||
k_tile_ptr = tl.make_block_ptr(
|
||||
base=K,
|
||||
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_kz, stride_kh, stride_kn, stride_kk),
|
||||
offsets=(off_hz // NH, off_hz % NH, 0, 0),
|
||||
block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),
|
||||
order=(3, 2, 1, 0),
|
||||
)
|
||||
v_tile_ptr = tl.make_block_ptr(
|
||||
base=V,
|
||||
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_vz, stride_vh, stride_vk, stride_vn),
|
||||
offsets=(off_hz // NH, off_hz % NH, 0, 0),
|
||||
block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),
|
||||
order=(3, 2, 1, 0),
|
||||
)
|
||||
o_tile_ptr = tl.make_block_ptr(
|
||||
base=Out,
|
||||
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_oz, stride_oh, stride_om, stride_on),
|
||||
offsets=(off_hz // NH, off_hz % NH, start_m, 0),
|
||||
block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),
|
||||
order=(3, 2, 1, 0),
|
||||
)
|
||||
|
||||
q = tl.load(q_tile_ptr, boundary_check=(0, 1, 2, 3))
|
||||
q = tl.view(q, (BLOCK_M, BLOCK_DMODEL))
|
||||
@@ -155,12 +163,13 @@ def test_batched_gemm_fusion():
|
||||
ref_out = torch.matmul(torch.matmul(A, BT), C)
|
||||
num_warps = 4
|
||||
grid = (triton.cdiv(N_CTX, BLOCK_M), B * NH)
|
||||
batched_gemm_fusion[grid](A, B, C, E,
|
||||
A.stride(0), A.stride(1), A.stride(2), A.stride(3),
|
||||
B.stride(0), B.stride(1), B.stride(2), B.stride(3),
|
||||
C.stride(0), C.stride(1), C.stride(2), C.stride(3),
|
||||
E.stride(0), E.stride(1), E.stride(2), E.stride(3),
|
||||
Z, NH, N_CTX,
|
||||
BLOCK_M, BLOCK_DMODEL, BLOCK_N, num_warps=num_warps)
|
||||
batched_gemm_fusion[grid](
|
||||
A, B, C, E, #
|
||||
A.stride(0), A.stride(1), A.stride(2), A.stride(3), #
|
||||
B.stride(0), B.stride(1), B.stride(2), B.stride(3), #
|
||||
C.stride(0), C.stride(1), C.stride(2), C.stride(3), #
|
||||
E.stride(0), E.stride(1), E.stride(2), E.stride(3), #
|
||||
Z, NH, N_CTX, #
|
||||
BLOCK_M, BLOCK_DMODEL, BLOCK_N, num_warps=num_warps)
|
||||
|
||||
torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0)
|
||||
|
||||
@@ -24,10 +24,8 @@ def add_kernel(
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
|
||||
x_block_ptr = tl.make_block_ptr(
|
||||
base=x_ptr, shape=(n_elements, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
|
||||
block_shape=(BLOCK_SIZE, ), order=(0, )
|
||||
)
|
||||
x_block_ptr = tl.make_block_ptr(base=x_ptr, shape=(n_elements, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
|
||||
block_shape=(BLOCK_SIZE, ), order=(0, ))
|
||||
x = tl.load(x_block_ptr, boundary_check=(0, ), padding_option='zero')
|
||||
|
||||
y = tl.load(y_ptr + offsets, mask=mask)
|
||||
@@ -36,9 +34,7 @@ def add_kernel(
|
||||
|
||||
|
||||
@pytest.mark.parametrize('SIZE,BLOCK_SIZE,dtype_str',
|
||||
[(98432, 1024, dtype_str)
|
||||
for dtype_str in ['float16', 'float32']
|
||||
])
|
||||
[(98432, 1024, dtype_str) for dtype_str in ['float16', 'float32']])
|
||||
def test_add(SIZE, BLOCK_SIZE, dtype_str):
|
||||
dtype = dtype_mapping[dtype_str]
|
||||
output = torch.empty(SIZE, device='cuda', dtype=dtype)
|
||||
@@ -46,7 +42,8 @@ def test_add(SIZE, BLOCK_SIZE, dtype_str):
|
||||
y = torch.randn(SIZE, device='cuda', dtype=dtype)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
|
||||
return (triton.cdiv(SIZE, meta['BLOCK_SIZE']), )
|
||||
|
||||
add_kernel[grid](x, y, output, SIZE, BLOCK_SIZE=BLOCK_SIZE)
|
||||
|
||||
output_torch = x + y
|
||||
@@ -64,25 +61,20 @@ def load_reduce_kernel(
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
x_ptr = tl.make_block_ptr(
|
||||
base=x_ptr, shape=(BLOCK_M, BLOCK_N), strides=(stride_xm, stride_xn),
|
||||
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)
|
||||
)
|
||||
x_ptr = tl.make_block_ptr(base=x_ptr, shape=(BLOCK_M, BLOCK_N), strides=(stride_xm, stride_xn), offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
|
||||
x = tl.load(x_ptr)
|
||||
y = tl.max(x, axis=1)
|
||||
tl.store(y_ptr + tl.arange(0, BLOCK_M), y)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,dtype_str',
|
||||
[(128, 64, dtype_str)
|
||||
for dtype_str in ['float16']
|
||||
])
|
||||
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,dtype_str', [(128, 64, dtype_str) for dtype_str in ['float16']])
|
||||
def test_load_reduce(BLOCK_M, BLOCK_N, dtype_str):
|
||||
dtype = dtype_mapping[dtype_str]
|
||||
x = torch.randn((BLOCK_M, BLOCK_N), device='cuda', dtype=dtype)
|
||||
y = torch.empty((BLOCK_M, ), device='cuda', dtype=dtype)
|
||||
|
||||
load_reduce_kernel[(1,)](x, y, x.stride(0), x.stride(1), y.stride(0), BLOCK_M, BLOCK_N)
|
||||
load_reduce_kernel[(1, )](x, y, x.stride(0), x.stride(1), y.stride(0), BLOCK_M, BLOCK_N)
|
||||
|
||||
golden = x.max(dim=1)[0]
|
||||
torch.set_printoptions(profile='full')
|
||||
|
||||
@@ -18,7 +18,6 @@
|
||||
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
@@ -40,18 +39,17 @@ import triton.language as tl
|
||||
key=['Q', 'K', 'V'],
|
||||
)
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q, K, V, sm_scale,
|
||||
L, M,
|
||||
Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
stride_oz, stride_oh, stride_om, stride_on,
|
||||
Z, H, N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
def _fwd_kernel(Q, K, V, sm_scale, #
|
||||
L, M, #
|
||||
Out, #
|
||||
stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vk, stride_vn, #
|
||||
stride_oz, stride_oh, stride_om, stride_on, #
|
||||
Z, H, N_CTX, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
|
||||
BLOCK_N: tl.constexpr #
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
# initialize offsets
|
||||
@@ -116,11 +114,10 @@ def _fwd_kernel(
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bwd_preprocess(
|
||||
Out, DO, L,
|
||||
NewDO, Delta,
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
|
||||
):
|
||||
def _bwd_preprocess(Out, DO, L, #
|
||||
NewDO, Delta, #
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr #
|
||||
):
|
||||
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, D_HEAD)
|
||||
# load
|
||||
@@ -136,19 +133,18 @@ def _bwd_preprocess(
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bwd_kernel(
|
||||
Q, K, V, sm_scale, Out, DO,
|
||||
DQ, DK, DV,
|
||||
L, M,
|
||||
D,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
Z, H, N_CTX,
|
||||
num_block,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
def _bwd_kernel(Q, K, V, sm_scale, Out, DO, #
|
||||
DQ, DK, DV, #
|
||||
L, M, #
|
||||
D, #
|
||||
stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vk, stride_vn, #
|
||||
Z, H, N_CTX, #
|
||||
num_block, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
|
||||
BLOCK_N: tl.constexpr, #
|
||||
):
|
||||
off_hz = tl.program_id(0)
|
||||
off_z = off_hz // H
|
||||
off_h = off_hz % H
|
||||
@@ -240,16 +236,16 @@ class _attention(torch.autograd.Function):
|
||||
assert num_warps == 4
|
||||
|
||||
_fwd_kernel[grid](
|
||||
q, k, v, sm_scale,
|
||||
L, m,
|
||||
o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=Lk,
|
||||
q, k, v, sm_scale, #
|
||||
L, m, #
|
||||
o, #
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
|
||||
q.shape[0], q.shape[1], q.shape[2], #
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK, #
|
||||
BLOCK_DMODEL=Lk #
|
||||
)
|
||||
|
||||
ctx.save_for_backward(q, k, v, o, L, m)
|
||||
@@ -269,24 +265,23 @@ class _attention(torch.autograd.Function):
|
||||
do_scaled = torch.empty_like(do)
|
||||
delta = torch.empty_like(l)
|
||||
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
|
||||
o, do, l,
|
||||
do_scaled, delta,
|
||||
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
_bwd_kernel[(ctx.grid[1],)](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do_scaled,
|
||||
dq, dk, dv,
|
||||
l, m,
|
||||
delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
ctx.grid[0],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
|
||||
num_stages=1,
|
||||
o, do, l, #
|
||||
do_scaled, delta, #
|
||||
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL)
|
||||
_bwd_kernel[(ctx.grid[1], )](
|
||||
q, k, v, ctx.sm_scale, #
|
||||
o, do_scaled, #
|
||||
dq, dk, dv, #
|
||||
l, m, #
|
||||
delta, #
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
|
||||
q.shape[0], q.shape[1], q.shape[2], #
|
||||
ctx.grid[0], #
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK, #
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, #
|
||||
num_warps=8, num_stages=1 #
|
||||
)
|
||||
return dq, dk, dv, None
|
||||
|
||||
@@ -339,19 +334,19 @@ except BaseException:
|
||||
|
||||
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
|
||||
# vary seq length for fixed head and batch=4
|
||||
configs = [triton.testing.Benchmark(
|
||||
x_names=['N_CTX'],
|
||||
# x_vals=[2**i for i in range(10, 14)],
|
||||
x_vals=[2**i for i in range(10, 11)],
|
||||
line_arg='provider',
|
||||
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
|
||||
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
|
||||
styles=[('red', '-'), ('blue', '-')],
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
||||
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
|
||||
# ) for mode in ['fwd', 'bwd']]
|
||||
) for mode in ['fwd']]
|
||||
configs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=['N_CTX'],
|
||||
# x_vals=[2**i for i in range(10, 14)],
|
||||
x_vals=[2**i
|
||||
for i in range(10, 11)], line_arg='provider', line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
|
||||
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), styles=[('red', '-'), ('blue', '-')], ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
||||
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
|
||||
# ) for mode in ['fwd', 'bwd']]
|
||||
)
|
||||
for mode in ['fwd']
|
||||
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
@@ -374,9 +369,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
|
||||
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
if provider == "flash":
|
||||
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
|
||||
cu_seqlens = torch.zeros(
|
||||
(BATCH + 1,), device=device, dtype=torch.int32)
|
||||
lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device)
|
||||
cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32)
|
||||
cu_seqlens[1:] = lengths.cumsum(0)
|
||||
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
|
||||
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
|
||||
|
||||
@@ -29,14 +29,14 @@ import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def static_persistent_matmul_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
NUM_SM: tl.constexpr,
|
||||
def static_persistent_matmul_kernel( #
|
||||
a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
NUM_SM: tl.constexpr #
|
||||
):
|
||||
start_tile = tl.program_id(axis=0)
|
||||
m_tiles = tl.cdiv(M, BLOCK_M)
|
||||
@@ -68,14 +68,14 @@ def static_persistent_matmul_kernel(
|
||||
|
||||
|
||||
@triton.jit
|
||||
def static_persistent_tma_matmul_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
NUM_SM: tl.constexpr,
|
||||
def static_persistent_tma_matmul_kernel( #
|
||||
a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
NUM_SM: tl.constexpr #
|
||||
):
|
||||
start_tile = tl.program_id(axis=0)
|
||||
m_tiles = tl.cdiv(M, BLOCK_M)
|
||||
@@ -88,8 +88,10 @@ def static_persistent_tma_matmul_kernel(
|
||||
|
||||
block_offset_m = pre_pid_m * BLOCK_M
|
||||
block_offset_n = pre_pid_n * BLOCK_N
|
||||
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
|
||||
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
|
||||
offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
|
||||
offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
|
||||
for tile_id in range(start_tile, num_tiles, NUM_SM):
|
||||
pid_m = tile_id // n_tiles
|
||||
pid_n = tile_id % n_tiles
|
||||
@@ -114,21 +116,23 @@ def static_persistent_tma_matmul_kernel(
|
||||
pre_pid_n = pid_n
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA',
|
||||
[(*shape, use_tma)
|
||||
for shape in [
|
||||
[4096, 4096, 64, 64, 64, 16, 4, 1, False, True],
|
||||
[4096, 4096, 64, 64, 64, 32, 4, 1, False, True],
|
||||
[4096, 4096, 64, 256, 64, 16, 4, 1, False, True],
|
||||
[4096, 4096, 64, 128, 128, 16, 4, 1, False, True],
|
||||
# TODO: fix issue for 8-warp persistent kernel
|
||||
# [4096, 4096, 64, 128, 128, 16, 8, 1, False, True],
|
||||
# [4096, 4096, 64, 128, 256, 16, 8, 1, False, True],
|
||||
]
|
||||
for use_tma in [False, True]
|
||||
])
|
||||
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA', [(
|
||||
*shape, use_tma
|
||||
) for shape in [
|
||||
[4096, 4096, 64, 64, 64, 16, 4, 1, False, True],
|
||||
[4096, 4096, 64, 64, 64, 32, 4, 1, False, True
|
||||
],
|
||||
[4096, 4096, 64, 256, 64, 16, 4, 1, False, True
|
||||
],
|
||||
[4096, 4096, 64, 128, 128, 16, 4, 1, False, True
|
||||
],
|
||||
# TODO: fix issue for 8-warp persistent kernel
|
||||
# [4096, 4096, 64, 128, 128, 16, 8, 1, False, True],
|
||||
# [4096, 4096, 64, 128, 256, 16, 8, 1, False, True],
|
||||
] for use_tma in [False, True]])
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA):
|
||||
def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS,
|
||||
TRANS_A, TRANS_B, USE_TMA):
|
||||
if (TRANS_A):
|
||||
a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T
|
||||
else:
|
||||
@@ -141,25 +145,33 @@ def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLO
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
|
||||
|
||||
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
|
||||
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),)
|
||||
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )
|
||||
|
||||
if USE_TMA:
|
||||
static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS)
|
||||
static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0),
|
||||
stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS,
|
||||
num_ctas=NUM_CTAS)
|
||||
else:
|
||||
static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS)
|
||||
static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0),
|
||||
stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS,
|
||||
num_ctas=NUM_CTAS)
|
||||
|
||||
th_c = torch.matmul(a, b)
|
||||
torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def warp_specialized_matmul_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
def warp_specialized_matmul_kernel( #
|
||||
a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
):
|
||||
tid = tl.program_id(axis=0)
|
||||
n_tiles = tl.cdiv(N, BLOCK_N)
|
||||
@@ -193,13 +205,13 @@ def warp_specialized_matmul_kernel(
|
||||
|
||||
|
||||
@triton.jit
|
||||
def tma_warp_specialized_matmul_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
def tma_warp_specialized_matmul_kernel( #
|
||||
a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
):
|
||||
tid = tl.program_id(axis=0)
|
||||
n_tiles = tl.cdiv(N, BLOCK_N)
|
||||
@@ -232,8 +244,7 @@ def tma_warp_specialized_matmul_kernel(
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA',
|
||||
[(*shape, use_tma)
|
||||
for shape in [
|
||||
[(*shape, use_tma) for shape in [
|
||||
[2048, 2048, 64, 64, 64, 16, 1, False, True],
|
||||
[4096, 4096, 64, 64, 64, 16, 1, False, True],
|
||||
[128, 4096, 64, 64, 64, 16, 1, False, True],
|
||||
@@ -257,9 +268,7 @@ def tma_warp_specialized_matmul_kernel(
|
||||
[4096, 4096, 128, 256, 128, 64, 4, False, True],
|
||||
[4096, 4096, 256, 128, 256, 64, 4, False, True],
|
||||
[4096, 4096, 256, 256, 256, 64, 4, False, True],
|
||||
]
|
||||
for use_tma in [False, True]
|
||||
])
|
||||
] for use_tma in [False, True]])
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA):
|
||||
if (TRANS_A):
|
||||
@@ -274,29 +283,29 @@ def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K
|
||||
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
|
||||
|
||||
grid = lambda META: (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),)
|
||||
grid = lambda META: (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), )
|
||||
|
||||
if USE_TMA:
|
||||
tma_warp_specialized_matmul_kernel[grid](
|
||||
a, b, c,
|
||||
M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
BLOCK_M, BLOCK_N, BLOCK_K,
|
||||
num_warps=4,
|
||||
num_ctas=NUM_CTAS,
|
||||
a, b, c, #
|
||||
M, N, K, #
|
||||
a.stride(0), a.stride(1), #
|
||||
b.stride(0), b.stride(1), #
|
||||
c.stride(0), c.stride(1), #
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, #
|
||||
num_warps=4, #
|
||||
num_ctas=NUM_CTAS, #
|
||||
enable_warp_specialization=True)
|
||||
else:
|
||||
warp_specialized_matmul_kernel[grid](
|
||||
a, b, c,
|
||||
M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
BLOCK_M, BLOCK_N, BLOCK_K,
|
||||
num_warps=4,
|
||||
num_ctas=NUM_CTAS,
|
||||
a, b, c, #
|
||||
M, N, K, #
|
||||
a.stride(0), a.stride(1), #
|
||||
b.stride(0), b.stride(1), #
|
||||
c.stride(0), c.stride(1), #
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, #
|
||||
num_warps=4, #
|
||||
num_ctas=NUM_CTAS, #
|
||||
enable_warp_specialization=True)
|
||||
|
||||
th_c = torch.matmul(a, b)
|
||||
@@ -304,14 +313,14 @@ def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K
|
||||
|
||||
|
||||
@triton.jit
|
||||
def static_persistent_warp_specialized_matmul_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
NUM_SM: tl.constexpr,
|
||||
def static_persistent_warp_specialized_matmul_kernel( #
|
||||
a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
NUM_SM: tl.constexpr #
|
||||
):
|
||||
start_tile = tl.program_id(axis=0)
|
||||
m_tiles = tl.cdiv(M, BLOCK_M)
|
||||
@@ -343,14 +352,14 @@ def static_persistent_warp_specialized_matmul_kernel(
|
||||
|
||||
|
||||
@triton.jit
|
||||
def static_persistent_tma_warp_specialized_matmul_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
NUM_SM: tl.constexpr,
|
||||
def static_persistent_tma_warp_specialized_matmul_kernel( #
|
||||
a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
NUM_SM: tl.constexpr #
|
||||
):
|
||||
start_tile = tl.program_id(axis=0)
|
||||
m_tiles = tl.cdiv(M, BLOCK_M)
|
||||
@@ -363,8 +372,10 @@ def static_persistent_tma_warp_specialized_matmul_kernel(
|
||||
|
||||
block_offset_m = pre_pid_m * BLOCK_M
|
||||
block_offset_n = pre_pid_n * BLOCK_N
|
||||
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
|
||||
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
|
||||
offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
|
||||
offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
|
||||
for tile_id in range(start_tile, num_tiles, NUM_SM):
|
||||
pid_m = tile_id // n_tiles
|
||||
pid_n = tile_id % n_tiles
|
||||
@@ -390,8 +401,7 @@ def static_persistent_tma_warp_specialized_matmul_kernel(
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA',
|
||||
[(*shape, use_tma)
|
||||
for shape in [
|
||||
[(*shape, use_tma) for shape in [
|
||||
[2048, 2048, 64, 64, 64, 16, 1, False, True],
|
||||
[4096, 4096, 64, 64, 64, 16, 1, False, True],
|
||||
[128, 4096, 64, 64, 64, 16, 1, False, True],
|
||||
@@ -415,11 +425,10 @@ def static_persistent_tma_warp_specialized_matmul_kernel(
|
||||
[4096, 4096, 128, 256, 128, 64, 4, False, True],
|
||||
[4096, 4096, 256, 128, 256, 64, 4, False, True],
|
||||
[4096, 4096, 256, 256, 256, 64, 4, False, True],
|
||||
]
|
||||
for use_tma in [False, True]
|
||||
])
|
||||
] for use_tma in [False, True]])
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA):
|
||||
def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B,
|
||||
USE_TMA):
|
||||
if (TRANS_A):
|
||||
a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T
|
||||
else:
|
||||
@@ -432,27 +441,22 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
|
||||
|
||||
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
|
||||
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),)
|
||||
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )
|
||||
|
||||
if USE_TMA:
|
||||
static_persistent_tma_warp_specialized_matmul_kernel[grid](
|
||||
a, b, c,
|
||||
M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, num_SMs,
|
||||
num_warps=4, num_ctas=NUM_CTAS,
|
||||
a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M,
|
||||
BLOCK_N, BLOCK_K, num_SMs, num_warps=4, num_ctas=NUM_CTAS, #
|
||||
enable_warp_specialization=True)
|
||||
else:
|
||||
static_persistent_warp_specialized_matmul_kernel[grid](
|
||||
a, b, c,
|
||||
M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, num_SMs,
|
||||
num_warps=4, num_ctas=NUM_CTAS,
|
||||
a, b, c, #
|
||||
M, N, K, #
|
||||
a.stride(0), a.stride(1), #
|
||||
b.stride(0), b.stride(1), #
|
||||
c.stride(0), c.stride(1), #
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, num_SMs, #
|
||||
num_warps=4, num_ctas=NUM_CTAS, #
|
||||
enable_warp_specialization=True)
|
||||
|
||||
th_c = torch.matmul(a, b)
|
||||
@@ -460,16 +464,15 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N
|
||||
|
||||
|
||||
@triton.jit
|
||||
def static_persistent_matmul_no_scf_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr,
|
||||
NUM_SM: tl.constexpr, USE_TMA_LOAD: tl.constexpr,
|
||||
):
|
||||
def static_persistent_matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr, #
|
||||
NUM_SM: tl.constexpr, USE_TMA_LOAD: tl.constexpr #
|
||||
):
|
||||
start_tile = tl.program_id(axis=0)
|
||||
m_tiles = tl.cdiv(M, BLOCK_M)
|
||||
n_tiles = tl.cdiv(N, BLOCK_N)
|
||||
@@ -487,7 +490,8 @@ def static_persistent_matmul_no_scf_kernel(
|
||||
offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
|
||||
if USE_TMA_EPILOGUE:
|
||||
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
|
||||
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
|
||||
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N),
|
||||
order=(1, 0))
|
||||
|
||||
for tile_id in range(start_tile, num_tiles, NUM_SM):
|
||||
pid_m = tile_id // n_tiles
|
||||
@@ -524,29 +528,27 @@ def static_persistent_matmul_no_scf_kernel(
|
||||
pre_pid_n = pid_n
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,USE_TMA_LOAD',
|
||||
itertools.chain(
|
||||
*[
|
||||
[
|
||||
# numCTAs = 1, no TMA multicast:
|
||||
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
# small M, N
|
||||
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
] for USE_TMA_EPILOGUE in [True, False]
|
||||
for USE_TMA_LOAD in [True, False]
|
||||
]))
|
||||
@pytest.mark.parametrize(
|
||||
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,USE_TMA_LOAD',
|
||||
itertools.chain(*[[
|
||||
# numCTAs = 1, no TMA multicast:
|
||||
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
# small M, N
|
||||
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
] for USE_TMA_EPILOGUE in [True, False] for USE_TMA_LOAD in [True, False]]))
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, USE_TMA_LOAD):
|
||||
def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE,
|
||||
USE_TMA_EPILOGUE, USE_TMA_LOAD):
|
||||
if (TRANS_A):
|
||||
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
|
||||
else:
|
||||
@@ -564,46 +566,42 @@ def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TR
|
||||
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
|
||||
|
||||
# TODO: set `enable_warp_specialization=False` will lead to compilation error.
|
||||
static_persistent_matmul_no_scf_kernel[(num_SMs,)](a_ptr=a, b_ptr=b, c_ptr=c,
|
||||
M=M, N=N, K=K,
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1),
|
||||
BLOCK_M=M if M < 128 else M // 2, BLOCK_N=N if N < 128 else N // 2, BLOCK_K=K, NUM_SM=num_SMs,
|
||||
num_warps=NUM_WARPS,
|
||||
num_ctas=NUM_CTAS,
|
||||
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"),
|
||||
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE,
|
||||
USE_TMA_LOAD=USE_TMA_LOAD,
|
||||
enable_warp_specialization=True)
|
||||
static_persistent_matmul_no_scf_kernel[(num_SMs, )](
|
||||
a_ptr=a, b_ptr=b, c_ptr=c, #
|
||||
M=M, N=N, K=K, #
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1), #
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1), #
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1), #
|
||||
BLOCK_M=M if M < 128 else M // 2, BLOCK_N=N if N < 128 else N // 2, BLOCK_K=K, NUM_SM=num_SMs, #
|
||||
num_warps=NUM_WARPS, #
|
||||
num_ctas=NUM_CTAS, #
|
||||
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), #
|
||||
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE, #
|
||||
USE_TMA_LOAD=USE_TMA_LOAD, #
|
||||
enable_warp_specialization=True)
|
||||
a_f32 = a.to(torch.float32)
|
||||
b_f32 = b.to(torch.float32)
|
||||
golden = torch.matmul(a_f32, b_f32)
|
||||
torch.set_printoptions(profile="full")
|
||||
assert_close(
|
||||
c,
|
||||
golden,
|
||||
rtol=1e-2,
|
||||
atol=1e-3,
|
||||
check_dtype=False)
|
||||
assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def full_static_persistent_matmul_kernel(
|
||||
a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_wm, stride_wn,
|
||||
stride_zm, stride_zn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
|
||||
out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr,
|
||||
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
|
||||
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
|
||||
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr,
|
||||
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr,
|
||||
NUM_SM: tl.constexpr
|
||||
):
|
||||
def full_static_persistent_matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_wm, stride_wn, #
|
||||
stride_zm, stride_zn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr, #
|
||||
out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr, #
|
||||
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, #
|
||||
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, #
|
||||
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, #
|
||||
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, #
|
||||
NUM_SM: tl.constexpr #
|
||||
):
|
||||
start_pid = tl.program_id(axis=0)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_N)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_M)
|
||||
@@ -618,15 +616,18 @@ def full_static_persistent_matmul_kernel(
|
||||
pre_block_offset_m = pre_pid_m * BLOCK_M
|
||||
pre_block_offset_n = pre_pid_n * BLOCK_N
|
||||
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
|
||||
offsets=(pre_block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(A_ORDER_0, A_ORDER_1))
|
||||
offsets=(pre_block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K),
|
||||
order=(A_ORDER_0, A_ORDER_1))
|
||||
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
|
||||
offsets=(0, pre_block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(B_ORDER_0, B_ORDER_1))
|
||||
offsets=(0, pre_block_offset_n), block_shape=(BLOCK_K, BLOCK_N),
|
||||
order=(B_ORDER_0, B_ORDER_1))
|
||||
w_tile_ptr = tl.make_block_ptr(base=w_ptr, shape=(N, N), strides=(stride_wm, stride_wn),
|
||||
offsets=(0, pre_block_offset_n), block_shape=(BLOCK_N, BLOCK_N), order=(0, 1))
|
||||
|
||||
if USE_TMA_STORE:
|
||||
z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N), strides=(stride_zm, stride_zn),
|
||||
offsets=(pre_block_offset_m, pre_block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
|
||||
offsets=(pre_block_offset_m, pre_block_offset_n),
|
||||
block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))

for tile_id in range(start_pid, num_tiles, NUM_SM):
group_id = tile_id // num_pid_in_group
@@ -694,136 +695,120 @@ def full_static_persistent_matmul_kernel(
pre_pid_n = pid_n


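# Illustrative sketch (not from the repo diff): the GROUP_SIZE_M swizzle used by
# full_static_persistent_matmul_kernel. Tiles are walked group-by-group along M so that
# consecutive programs reuse the same B columns (better L2 hit rate). It mirrors the
# `group_id = tile_id // num_pid_in_group` arithmetic shown above; the remaining mapping
# follows the standard Triton grouped ordering and is an assumption, since the kernel body
# is elided in this hunk.
def grouped_tile_order(num_pid_m, num_pid_n, group_size_m):
    num_pid_in_group = group_size_m * num_pid_n
    order = []
    for tile_id in range(num_pid_m * num_pid_n):
        group_id = tile_id // num_pid_in_group
        first_pid_m = group_id * group_size_m
        size_m = min(num_pid_m - first_pid_m, group_size_m)  # last group along M may be short
        pid_m = first_pid_m + (tile_id % size_m)
        pid_n = (tile_id % num_pid_in_group) // size_m
        order.append((pid_m, pid_n))
    return order

# Example: a 4x4 tile grid with groups of 2 rows visits every (pid_m, pid_n) exactly once.
assert sorted(grouped_tile_order(4, 4, 2)) == [(m, n) for m in range(4) for n in range(4)]
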
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
|
||||
[
|
||||
# corner shapes
|
||||
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws)
|
||||
for shape_w_c in [
|
||||
[4096, 1, 1024, False, False],
|
||||
[2048, 204, 1000, True, False],
|
||||
[16, 524288, 32, False, True],
|
||||
]
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for enable_ws in [True]
|
||||
] + [
|
||||
# softmax epilogue
|
||||
(*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
# softmax works for one CTA
|
||||
for shape_w_c in [
|
||||
[64, 64, 16, 4, 1, 64, 64, 64],
|
||||
[128, 128, 64, 4, 1, None, None, None],
|
||||
[16, 16, 64, 4, 1, 16, 16, 64],
|
||||
# TODO: enable when num_warps != 4 is supported.
|
||||
# [64, 64, 32, 8, 1, 64, 64, 64],
|
||||
[128, 128, 64, 4, 1, 128, 128, 128],
|
||||
]
|
||||
for epilogue in ['softmax']
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for trans_a in [False,]
|
||||
for trans_b in [True,]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [True]
|
||||
] + [
|
||||
# loop over tile shapes and transpose combinations
|
||||
(*shape_w_c, trans_a, trans_b, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 32, 4, 1, 128, 256, 64],
|
||||
[128, 128, 16, 4, 4, 512, 256, 64],
|
||||
[128, 256, 32, 4, 8, 256, 256, 192],
|
||||
[512, 256, 32, 4, 8, 1024, 256, 192],
|
||||
# BLOCK_K >= 128
|
||||
[64, 128, 128, 4, 1, 512, 256, 256],
|
||||
[128, 128, 128, 4, 1, 256, 256, 192],
|
||||
[128, 128, 128, 4, 2, 256, 256, 192],
|
||||
# small BLOCK_M and BLOCK_K
|
||||
[16, 32, 32, 4, 1, 128, 256, 64],
|
||||
[32, 32, 16, 4, 1, 256, 256, 192],
|
||||
[16, 32, 64, 4, 4, 512, 256, 64],
|
||||
]
|
||||
for out_dtype in ['float32',]
|
||||
for use_tma_store in [False,]
|
||||
for trans_a in [False, True]
|
||||
for trans_b in [False, True]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [True]
|
||||
] + [
|
||||
# loop over epilogues other than softmax
|
||||
(*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 16, 4, 1, 128, 128, 64],
|
||||
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4] for num_ctas in [1, 2, 4]],
|
||||
# for chain-dot
|
||||
[128, 128, 64, 4, 1, None, None, None],
|
||||
[64, 64, 16, 4, 1, None, None, None],
|
||||
# small BLOCK_M and BLOCK_K
|
||||
[16, 16, 64, 4, 1, 128, 128, 64],
|
||||
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4] for num_ctas in [1, 2]],
|
||||
# # TODO: enable when num_warps != 4 is supported.
|
||||
# # repeat
|
||||
# # [64, 64, 32, 8, 1, 128, 256, 64],
|
||||
# # [64, 64, 16, 8, 2, 128, 128, 64],
|
||||
# irregular shape
|
||||
[128, 128, 64, 4, 1, 500, 200, 128],
|
||||
[128, 128, 64, 4, 1, 513, 193, 192],
|
||||
]
|
||||
for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot']
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for trans_a in [False,]
|
||||
for trans_b in [True,]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [True]
|
||||
if not (epilogue == 'chain-dot' and (shape_w_c[5] is not None or shape_w_c[0] != shape_w_c[1]))
|
||||
] + [
|
||||
# loop over instr shapes & pipeline stages
|
||||
(64, n, 16, 4, 1, 512, 256, 256, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for n in [16, 32, 64, 128, 256]
|
||||
for out_dtype in ['float32']
|
||||
for use_tma_store in [False,]
|
||||
for num_stages in [2, 4, 5, 7]
|
||||
for enable_ws in [True]
|
||||
] + [
|
||||
# irregular shapes
|
||||
(*shape_w_c, *shape, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[128, 128, 64, 4, 1],
|
||||
[256, 128, 64, 4, 2],
|
||||
[128, 128, 128, 4, 2]
|
||||
]
|
||||
for shape in [
|
||||
[512, 360, 1024],
|
||||
[360, 4096, 512],
|
||||
]
|
||||
for out_dtype in ['float32']
|
||||
for use_tma_store in [False, True]
|
||||
for num_stages in [3, 4]
|
||||
for enable_ws in [True]
|
||||
]
|
||||
)
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()
|
||||
[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
|
||||
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS])) in [
|
||||
'128-128-128-4-1-256-256-192-none-float32-True-3-True',
|
||||
]:
|
||||
@pytest.mark.parametrize(
|
||||
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
|
||||
[
|
||||
# corner shapes
|
||||
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws) for shape_w_c in [
|
||||
[4096, 1, 1024, False, False],
|
||||
[2048, 204, 1000, True, False],
|
||||
[16, 524288, 32, False, True],
|
||||
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for enable_ws in [True]
|
||||
] + [
|
||||
# softmax epilogue
|
||||
(*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
# softmax works for one CTA
|
||||
for shape_w_c in [
|
||||
[64, 64, 16, 4, 1, 64, 64, 64],
|
||||
[128, 128, 64, 4, 1, None, None, None],
|
||||
[16, 16, 64, 4, 1, 16, 16, 64],
|
||||
# TODO: enable when num_warps != 4 is supported.
|
||||
# [64, 64, 32, 8, 1, 64, 64, 64],
|
||||
[128, 128, 64, 4, 1, 128, 128, 128],
|
||||
]
|
||||
for epilogue in ['softmax']
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for trans_a in [False]
|
||||
for trans_b in [True]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [True]
|
||||
] + [
|
||||
# loop over tile shapes and transpose combinations
|
||||
(*shape_w_c, trans_a, trans_b, 'none', out_dtype, use_tma_store, num_stages, enable_ws) for shape_w_c in [
|
||||
[64, 64, 32, 4, 1, 128, 256, 64],
|
||||
[128, 128, 16, 4, 4, 512, 256, 64],
|
||||
[128, 256, 32, 4, 8, 256, 256, 192],
|
||||
[512, 256, 32, 4, 8, 1024, 256, 192],
|
||||
# BLOCK_K >= 128
|
||||
[64, 128, 128, 4, 1, 512, 256, 256],
|
||||
[128, 128, 128, 4, 1, 256, 256, 192],
|
||||
[128, 128, 128, 4, 2, 256, 256, 192],
|
||||
# small BLOCK_M and BLOCK_K
|
||||
[16, 32, 32, 4, 1, 128, 256, 64],
|
||||
[32, 32, 16, 4, 1, 256, 256, 192],
|
||||
[16, 32, 64, 4, 4, 512, 256, 64],
|
||||
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
|
||||
[False, True] for num_stages in [3] for enable_ws in [True]
|
||||
] + [
|
||||
# loop over epilogues other than softmax
|
||||
(*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) for shape_w_c in [
|
||||
[64, 64, 16, 4, 1, 128, 128, 64],
|
||||
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4] for num_ctas in [1, 2, 4]],
|
||||
# for chain-dot
|
||||
[128, 128, 64, 4, 1, None, None, None],
|
||||
[64, 64, 16, 4, 1, None, None, None],
|
||||
# small BLOCK_M and BLOCK_K
|
||||
[16, 16, 64, 4, 1, 128, 128, 64],
|
||||
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4] for num_ctas in [1, 2]],
|
||||
# # TODO: enable when num_warps != 4 is supported.
|
||||
# # repeat
|
||||
# # [64, 64, 32, 8, 1, 128, 256, 64],
|
||||
# # [64, 64, 16, 8, 2, 128, 128, 64],
|
||||
# irregular shape
|
||||
[128, 128, 64, 4, 1, 500, 200, 128],
|
||||
[128, 128, 64, 4, 1, 513, 193, 192],
|
||||
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'] for out_dtype in
|
||||
['float16', 'float32'] for use_tma_store in [False, True] for trans_a in [False] for trans_b in [True] for
|
||||
num_stages in [3] for enable_ws in [True] if not (epilogue == 'chain-dot' and
|
||||
(shape_w_c[5] is not None or shape_w_c[0] != shape_w_c[1]))
|
||||
] + [
|
||||
# loop over instr shapes & pipeline stages
|
||||
(64, n, 16, 4, 1, 512, 256, 256, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for n in [16, 32, 64, 128, 256]
|
||||
for out_dtype in ['float32']
|
||||
for use_tma_store in [False]
|
||||
for num_stages in [2, 4, 5, 7]
|
||||
for enable_ws in [True]
|
||||
] + [
|
||||
# irregular shapes
|
||||
(*shape_w_c, *shape, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [[128, 128, 64, 4, 1], [256, 128, 64, 4, 2], [128, 128, 128, 4, 2]]
|
||||
for shape in [
|
||||
[512, 360, 1024],
|
||||
[360, 4096, 512],
|
||||
]
|
||||
for out_dtype in ['float32']
|
||||
for use_tma_store in [False, True]
|
||||
for num_stages in [3, 4]
|
||||
for enable_ws in [True]
|
||||
])
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B,
|
||||
epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
|
||||
if '-'.join(
|
||||
map(str, [
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES,
|
||||
ENABLE_WS
|
||||
])) in [
|
||||
'128-128-128-4-1-256-256-192-none-float32-True-3-True',
|
||||
]:
|
||||
pytest.skip('out of resource: shared memory, Required: 263168')
|
||||
|
||||
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
|
||||
'16-32-64-4-4-512-256-64-True-False',
|
||||
'16-32-64-4-4-512-256-64-True-True',
|
||||
'16-32-64-4-4-512-256-64-False-False',
|
||||
'16-32-64-4-4-512-256-64-False-True',
|
||||
'16-32-64-4-4-512-256-64-True-False',
|
||||
'16-32-64-4-4-512-256-64-True-True',
|
||||
'16-32-64-4-4-512-256-64-False-False',
|
||||
'16-32-64-4-4-512-256-64-False-True',
|
||||
]:
|
||||
pytest.skip('shapePerCTA[1] < 16 not supported')
|
||||
|
||||
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
|
||||
'16-32-64-4-1-256-256-256-False',
|
||||
'16-32-64-4-2-256-256-256-False',
|
||||
'16-32-64-4-2-256-256-256-True',
|
||||
'16-32-64-8-2-256-256-256-False',
|
||||
'16-32-64-8-2-256-256-256-True',
|
||||
'16-32-64-4-1-256-256-256-False',
|
||||
'16-32-64-4-2-256-256-256-False',
|
||||
'16-32-64-4-2-256-256-256-True',
|
||||
'16-32-64-8-2-256-256-256-False',
|
||||
'16-32-64-8-2-256-256-256-True',
|
||||
]:
|
||||
pytest.skip('Known legacy issue, ldmatrix can only support x4')
|
||||
|
||||
@@ -893,37 +878,36 @@ def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WAR
|
||||
else:
|
||||
ref = d
|
||||
return ref
|
||||
|
||||
golden = process_epilogue(dot, bias, w, epilogue)
|
||||
|
||||
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
|
||||
|
||||
def grid(META):
|
||||
return (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),)
|
||||
return (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )
|
||||
|
||||
full_static_persistent_matmul_kernel[grid](
|
||||
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z,
|
||||
M=M, N=N, K=K,
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_wm=w.stride(0), stride_wn=w.stride(1),
|
||||
stride_zm=z.stride(0), stride_zn=z.stride(1),
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8,
|
||||
out_dtype=out_dtype,
|
||||
USE_TMA_STORE=USE_TMA_STORE,
|
||||
ADD_MATRIX=epilogue == 'add-matrix',
|
||||
ADD_ROWS=epilogue == 'add-rows',
|
||||
ADD_COLS=epilogue == 'add-cols',
|
||||
DO_SOFTMAX=epilogue == 'softmax',
|
||||
CHAIN_DOT=epilogue == 'chain-dot',
|
||||
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1],
|
||||
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1],
|
||||
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES,
|
||||
enable_warp_specialization=ENABLE_WS,
|
||||
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
|
||||
M=M, N=N, K=K, #
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1), #
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1), #
|
||||
stride_wm=w.stride(0), stride_wn=w.stride(1), #
|
||||
stride_zm=z.stride(0), stride_zn=z.stride(1), #
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
|
||||
out_dtype=out_dtype, #
|
||||
USE_TMA_STORE=USE_TMA_STORE, #
|
||||
ADD_MATRIX=epilogue == 'add-matrix', #
|
||||
ADD_ROWS=epilogue == 'add-rows', #
|
||||
ADD_COLS=epilogue == 'add-cols', #
|
||||
DO_SOFTMAX=epilogue == 'softmax', #
|
||||
CHAIN_DOT=epilogue == 'chain-dot', #
|
||||
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
|
||||
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
|
||||
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, #
|
||||
enable_warp_specialization=ENABLE_WS, #
|
||||
NUM_SM=num_SMs)

torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden,
rtol=1e-2,
atol=1e-3,
check_dtype=False)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)

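# Illustrative sketch (not from the repo diff): plausible torch references for the
# epilogue names parametrized above. The test's own process_epilogue body is elided in
# this hunk, so the bias shapes here are assumptions made only to show what each variant
# adds on top of the dot product.
import torch

def reference_epilogue(d, bias, w, epilogue):
    if epilogue == 'add-matrix':
        return d + bias                      # full-tile bias (assumed shape: same as d)
    if epilogue == 'add-rows':
        return d + bias[:, 0][:, None]       # one value per row, broadcast across columns
    if epilogue == 'add-cols':
        return d + bias[0, :][None, :]       # one value per column, broadcast across rows
    if epilogue == 'softmax':
        return torch.softmax(d.float(), dim=-1).to(d.dtype)
    if epilogue == 'chain-dot':
        return torch.matmul(d, w)            # feed the first dot into a second one
    return d                                 # 'none'
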
@@ -19,7 +19,6 @@
|
||||
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
@@ -29,21 +28,21 @@ import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def matmul_tma_load_store(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
OUTPUT_F16: tl.constexpr
|
||||
def matmul_tma_load_store( #
|
||||
a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
OUTPUT_F16: tl.constexpr #
|
||||
):
|
||||
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
|
||||
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
|
||||
offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
|
||||
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
|
||||
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
|
||||
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0),
|
||||
block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
|
||||
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
|
||||
a = tl.load(a_block_ptr)
|
||||
b = tl.load(b_block_ptr)
|
||||
|
||||
@@ -78,15 +77,15 @@ def test_tma_load_store(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_F
|
||||
if OUTPUT_F16:
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.float16)
|
||||
|
||||
matmul_tma_load_store[(1, 1)](a_ptr=a, b_ptr=b, c_ptr=c,
|
||||
M=M, N=N, K=K,
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1),
|
||||
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,
|
||||
num_warps=NUM_WARPS,
|
||||
num_ctas=NUM_CTAS,
|
||||
OUTPUT_F16=OUTPUT_F16)
|
||||
matmul_tma_load_store[(1, 1)](
|
||||
a_ptr=a, b_ptr=b, c_ptr=c, #
|
||||
M=M, N=N, K=K, #
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1), #
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1), #
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1), #
|
||||
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, #
|
||||
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, #
|
||||
OUTPUT_F16=OUTPUT_F16)
|
||||
golden = torch.matmul(a, b)
|
||||
torch.set_printoptions(profile="full")
|
||||
assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
|
||||
|
||||
@@ -54,17 +54,13 @@ def test_tma_wgmma_64_64_16_f16(TTGIR, TRANS_A, TRANS_B):
|
||||
|
||||
ttgir_path = os.path.dirname(__file__) + "/" + TTGIR
|
||||
kernel = triton.compile(ttgir_path)
|
||||
kernel[(1, 1, 1)](a.data_ptr(), b.data_ptr(), c.data_ptr(),
|
||||
SIZE_M, SIZE_N, SIZE_K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0))
|
||||
kernel[(1, 1, 1)]( #
|
||||
a.data_ptr(), b.data_ptr(), c.data_ptr(), #
|
||||
SIZE_M, SIZE_N, SIZE_K, #
|
||||
a.stride(0), a.stride(1), #
|
||||
b.stride(0), b.stride(1), #
|
||||
c.stride(0))
|
||||
|
||||
golden = torch.matmul(a, b)
|
||||
torch.set_printoptions(profile="full", sci_mode=False)
|
||||
assert_close(
|
||||
c,
|
||||
golden,
|
||||
rtol=1e-2,
|
||||
atol=1e-3,
|
||||
check_dtype=False)
|
||||
assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
|
||||
|
||||
@@ -15,9 +15,9 @@ def kernel_device_assert(X, Y, BLOCK: tl.constexpr):
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_device_assert_scalar(X, Y, BLOCK: tl.constexpr):
|
||||
def kernel_assert_passes(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
# Trivial assert
|
||||
# Trivial assert, should not be an error.
|
||||
tl.device_assert(0 == 0, "x != 0")
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
@@ -48,6 +48,7 @@ def test_assert(func: str):
|
||||
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')
|
||||
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
|
||||
if func == "device_assert":
|
||||
<<<<<<< HEAD
|
||||
kernel_device_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
kernel_device_assert_scalar[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
elif func == "no_debug":
|
||||
@@ -55,8 +56,32 @@ def test_assert(func: str):
|
||||
kernel_device_assert_no_debug[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
elif func == "assert":
|
||||
kernel_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
=======
|
||||
kernel_device_assert[(1, )](x, y, BLOCK=shape[0])
|
||||
if func == "device_assert_passes":
|
||||
# Assert passes; no error.
|
||||
kernel_assert_passes[(1, )](x, y, BLOCK=shape[0])
|
||||
elif func == "no_debug":
|
||||
# TRITON_DEBUG=1 can override the debug flag
|
||||
kernel_device_assert_no_debug[(1, )](x, y, BLOCK=shape[0])
|
||||
elif func == "assert":
|
||||
kernel_assert[(1, )](x, y, BLOCK=shape[0])
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
elif func == "static_assert":
|
||||
kernel_static_assert[(1,)](x, y, BLOCK=shape[0])
|
||||
kernel_static_assert[(1, )](x, y, BLOCK=shape[0])
|
||||
elif func == "double_assert":
|
||||
# Launching a different kernel after the first one asserted used to
|
||||
# segfault. What seems to have happened is:
|
||||
# - The first kernel is enqueued but doesn't run yet.
|
||||
# - We go to launch the second kernel. Because this is the first time
|
||||
# we're running it, we have to load the kernel into the GPU.
|
||||
# - Loading the kernel takes some time, during which the first launch
|
||||
# completes.
|
||||
# - Now the GPU is in an error state. We need to detect this inside
|
||||
# the kernel-launch/loading code and bail out properly. If we don't,
|
||||
# we segfault.
|
||||
kernel_device_assert[(1, )](x, y, BLOCK=shape[0])
|
||||
kernel_assert_passes[(1, )](x, y, BLOCK=shape[0])
|
||||
assert_close(y, x)
|
||||
|
||||
|
||||
@@ -116,11 +141,19 @@ def test_assert_nested(caller: str, callee: str):
|
||||
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')
|
||||
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
|
||||
if caller == "none":
|
||||
<<<<<<< HEAD
|
||||
kernel_device_assert_nested[(1,)](x, y, num_warps=2, BLOCK=shape[0], jit_debug=callee)
|
||||
elif caller == "true":
|
||||
kernel_device_assert_nested_true[(1,)](x, y, num_warps=2, BLOCK=shape[0], jit_debug=callee)
|
||||
elif caller == "false":
|
||||
kernel_device_assert_nested_false[(1,)](x, y, num_warps=2, BLOCK=shape[0], jit_debug=callee)
|
||||
=======
|
||||
kernel_device_assert_nested[(1, )](x, y, BLOCK=shape[0], jit_debug=callee)
|
||||
elif caller == "true":
|
||||
kernel_device_assert_nested_true[(1, )](x, y, BLOCK=shape[0], jit_debug=callee)
|
||||
elif caller == "false":
|
||||
kernel_device_assert_nested_false[(1, )](x, y, BLOCK=shape[0], jit_debug=callee)
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
assert_close(y, x)
|
||||
|
||||
|
||||
|
||||
@@ -4,9 +4,7 @@ import pytest
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--device", action="store", default='cuda'
|
||||
)
|
||||
parser.addoption("--device", action="store", default='cuda')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
@@ -10,21 +11,49 @@ import triton.language as tl
|
||||
@triton.jit
|
||||
def kernel_device_print(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
tl.device_print("", x)
|
||||
tl.device_print("x: ", x)
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_print(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
print("", x)
|
||||
# Triton should add a space after this prefix.
|
||||
print("x:", x)
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_static_print(X, Y, BLOCK: tl.constexpr):
|
||||
def kernel_device_print_large(
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
x = tl.full([BLOCK_M, BLOCK_N], 1, tl.int32)
|
||||
# Triton should change this prefix to "x: ".
|
||||
tl.device_print("x ", x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_print_multiple_args(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
tl.static_print(x)
|
||||
y = tl.full((BLOCK, ), 1, tl.int32)
|
||||
print("", x, y)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_device_print_multiple_args(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
y = tl.full((BLOCK, ), 1, tl.int32)
|
||||
tl.device_print("", x, y)
|
||||
tl.store(Y + tl.arange(0, BLOCK), y)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_static_print(X, Y, BLOCK: tl.constexpr, PLACEHOLDER: tl.constexpr):
|
||||
# This function takes an extra value as a tl.constexpr so this kernel is not
|
||||
# cached. This way the static print is run every time.
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
tl.static_print("", x)
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@@ -33,21 +62,36 @@ def kernel_no_arg_print():
|
||||
print("", tl.program_id(0))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_print_no_arg():
|
||||
print("no arg")
|
||||
|
||||
|
||||
def test_print(func: str, data_type: str):
|
||||
shape = (128, )
|
||||
# limit the range of integers so that the sum does not overflow
|
||||
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda').to(getattr(torch, data_type))
|
||||
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
|
||||
if func == "device_print":
|
||||
kernel_device_print[(1,)](x, y, BLOCK=shape[0])
|
||||
kernel_device_print[(1, )](x, y, BLOCK=shape[0])
|
||||
elif func == "print":
|
||||
kernel_print[(1,)](x, y, BLOCK=shape[0])
|
||||
kernel_print[(1, )](x, y, BLOCK=shape[0])
|
||||
elif func == "device_print_large":
|
||||
kernel_device_print_large[(1, 2)](BLOCK_M=64, BLOCK_N=128)
|
||||
elif func == "print_multiple_args":
|
||||
kernel_print_multiple_args[(1, )](x, y, BLOCK=shape[0])
|
||||
elif func == "device_print_multiple_args":
|
||||
kernel_device_print_multiple_args[(1, )](x, y, BLOCK=shape[0])
|
||||
elif func == "static_print":
|
||||
kernel_static_print[(1,)](x, y, BLOCK=shape[0])
|
||||
kernel_static_print[(1, )](x, y, BLOCK=shape[0], PLACEHOLDER=uuid.uuid4())
|
||||
elif func == "no_arg_print":
|
||||
kernel_no_arg_print[(1,)](num_warps=4)
|
||||
kernel_no_arg_print[(1, )](num_warps=4)
|
||||
elif func == "print_no_arg":
|
||||
kernel_print_no_arg[(1, )](num_warps=4)
|
||||
else:
|
||||
assert f"Unknown kernel: {func}"
|
||||
|
||||
if func != "no_arg_print":
|
||||
if func != "print_no_arg" and func != "no_arg_print" and func != "device_print_large" and \
|
||||
func != "print_multiple_args" and func != "device_print_multiple_args":
|
||||
assert_close(y, x)
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
@@ -14,8 +13,8 @@ def test_annotations(device):
|
||||
pass
|
||||
|
||||
x = torch.empty(1, device=device)
|
||||
_kernel[(1,)](x, x.shape[0], 32)
|
||||
_kernel[(1, )](x, x.shape[0], 32)
|
||||
try:
|
||||
_kernel[(1,)](x.shape[0], x.shape[0], 32)
|
||||
_kernel[(1, )](x.shape[0], x.shape[0], 32)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@@ -17,10 +17,12 @@ def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option:
|
||||
tl.store(b_block_ptr, a, boundary_check=(0, ))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str, n, padding_option",
|
||||
[(dtype_str, n, padding) for dtype_str in ("bool", "int16", "float16")
|
||||
for n in (64, 128, 256, 512, 1024)
|
||||
for padding in ("zero", "nan")])
|
||||
@pytest.mark.parametrize("dtype_str, n, padding_option", [ #
|
||||
(dtype_str, n, padding)
|
||||
for dtype_str in ("bool", "int16", "float16")
|
||||
for n in (64, 128, 256, 512, 1024)
|
||||
for padding in ("zero", "nan") #
|
||||
])
|
||||
def test_block_copy(dtype_str, n, padding_option):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if torch.version.hip is None and capability[0] >= 9:
|
||||
@@ -35,31 +37,31 @@ def test_block_copy(dtype_str, n, padding_option):
a = torch.randn((n, ), device="cuda", dtype=dtype)
b = torch.zeros((n, ), device="cuda", dtype=dtype)

grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]), )
block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, padding_option=padding_option)

assert torch.all(a[0: n // 2] == b[0: n // 2])
assert torch.all(a[0:n // 2] == b[0:n // 2])
if padding_option == "zero":
assert torch.all(b[n // 2: n] == 0)
assert torch.all(b[n // 2:n] == 0)
else:
assert torch.all(torch.isnan(b[n // 2: n]))
assert torch.all(torch.isnan(b[n // 2:n]))


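# Illustrative sketch (not from the repo diff): what padding_option means for the
# out-of-bounds half checked above. block_copy_kernel loads the valid prefix with
# boundary_check and the padding value fills the rest of the block. Torch-only model,
# assuming a float dtype (the "nan" case only makes sense for floats).
import torch

def padded_block_load(a, n_valid, block_size, padding_option):
    pad = float("nan") if padding_option == "nan" else 0.0
    out = torch.full((block_size, ), pad, dtype=a.dtype)
    k = min(n_valid, block_size)
    out[:k] = a[:k]
    return out

# Example: 4 valid elements in an 8-wide block, zero padding for the tail.
blk = padded_block_load(torch.arange(4, dtype=torch.float32), 4, 8, "zero")
assert torch.all(blk[4:] == 0)
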
@triton.jit
|
||||
def matmul_no_scf_with_advance_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr
|
||||
def matmul_no_scf_with_advance_kernel( #
|
||||
a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr #
|
||||
):
|
||||
offs_m = tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
|
||||
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
|
||||
offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))
|
||||
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0),
|
||||
block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))
|
||||
# Below two lines are just for testing negative offsets for the `advance` API, which could be removed
|
||||
a_block_ptr = tl.advance(a_block_ptr, (BLOCK_M, -BLOCK_K))
|
||||
a_block_ptr = tl.advance(a_block_ptr, (-BLOCK_M, BLOCK_K))
|
||||
@@ -71,14 +73,12 @@ def matmul_no_scf_with_advance_kernel(
|
||||
tl.store(c_ptrs, c)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape, num_warps", [
|
||||
(shape, num_warps)
|
||||
for shape in [
|
||||
@pytest.mark.parametrize("shape, num_warps", [ #
|
||||
(shape, num_warps) for shape in [
|
||||
[64, 64, 16],
|
||||
[64, 64, 32],
|
||||
[64, 64, 64],
|
||||
]
|
||||
for num_warps in [4, 8]
|
||||
] for num_warps in [4, 8]
|
||||
])
|
||||
def test_block_ptr_matmul_no_scf(shape, num_warps):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
@@ -91,12 +91,13 @@ def test_block_ptr_matmul_no_scf(shape, num_warps):
|
||||
c = torch.empty((m, n), device="cuda", dtype=torch.float32)
|
||||
|
||||
grid = lambda META: (1, )
|
||||
matmul_no_scf_with_advance_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
|
||||
M=m, N=n, K=k,
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1),
|
||||
BLOCK_M=m, BLOCK_N=n, BLOCK_K=k,
|
||||
num_warps=num_warps)
|
||||
matmul_no_scf_with_advance_kernel[grid](
|
||||
a_ptr=a, b_ptr=b, c_ptr=c, #
|
||||
M=m, N=n, K=k, #
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1), #
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1), #
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1), #
|
||||
BLOCK_M=m, BLOCK_N=n, BLOCK_K=k, #
|
||||
num_warps=num_warps)
golden = torch.matmul(a, b)
torch.testing.assert_close(c, golden, check_dtype=False)

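# Illustrative sketch (not from the repo diff): tl.advance only shifts a block pointer's
# offsets; it moves no data. A plain-Python stand-in (hypothetical class, not the Triton
# API) showing why the back-to-back advances by (+BLOCK_M, -BLOCK_K) and
# (-BLOCK_M, +BLOCK_K) in matmul_no_scf_with_advance_kernel cancel out.
class BlockPtr:
    def __init__(self, offsets):
        self.offsets = tuple(offsets)

    def advance(self, deltas):
        return BlockPtr(o + d for o, d in zip(self.offsets, deltas))

p = BlockPtr((0, 0))
p = p.advance((64, -32))    # like tl.advance(a_block_ptr, (BLOCK_M, -BLOCK_K))
p = p.advance((-64, 32))    # like tl.advance(a_block_ptr, (-BLOCK_M, BLOCK_K))
assert p.offsets == (0, 0)  # net effect: back to the original tile
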
File diff suppressed because it is too large
@@ -12,6 +12,7 @@ import triton.language as tl
|
||||
|
||||
|
||||
class PhiloxConfig:
|
||||
|
||||
def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
|
||||
self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
|
||||
self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE)
|
||||
@@ -40,6 +41,7 @@ PHILOX_64 = PhiloxConfig(
|
||||
|
||||
|
||||
class CustomPhilox4x:
|
||||
|
||||
def __init__(self, seed, config):
|
||||
self._config = config
|
||||
seed = self._into_pieces(seed)
|
||||
@@ -92,6 +94,7 @@ class CustomPhilox4x:
|
||||
|
||||
|
||||
class CustomPhilox(CustomPhilox4x):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.buffer = []
|
||||
@@ -111,10 +114,9 @@ BLOCK = 1024
|
||||
# test generation of random uint32
|
||||
|
||||
|
||||
@pytest.mark.parametrize('size, seed',
|
||||
[(size, seed) for size in ['10', '4,53', '10000']
|
||||
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
|
||||
)
|
||||
@pytest.mark.parametrize('size, seed', [(size, seed)
|
||||
for size in ['10', '4,53', '10000']
|
||||
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]])
|
||||
def test_randint(size, seed, device):
|
||||
size = list(map(int, size.split(',')))
|
||||
|
||||
@@ -123,10 +125,11 @@ def test_randint(size, seed, device):
|
||||
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
|
||||
rand = tl.randint(seed, offset)
|
||||
tl.store(X + offset, rand, mask=offset < N)
|
||||
|
||||
# triton result
|
||||
x = torch.empty(size, dtype=torch.int32, device=device)
|
||||
N = x.numel()
|
||||
grid = (triton.cdiv(N, BLOCK),)
|
||||
grid = (triton.cdiv(N, BLOCK), )
|
||||
kernel[grid](x, N, seed)
|
||||
out_tri = x.cpu().numpy().astype(np.uint32).flatten().tolist()
|
||||
# reference result
|
||||
@@ -134,44 +137,44 @@ def test_randint(size, seed, device):
out_ref = [gen.random_raw()[0] for _ in out_tri]
assert out_tri == out_ref


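# Illustrative sketch (not from the repo diff): one Philox-4x32 round, the operation the
# CustomPhilox4x reference above iterates to produce the stream checked by test_randint.
# Plain Python ints; the multipliers and key increments are the standard Philox-4x32
# constants and are an assumption here (the actual values live in the PhiloxConfig above).
def philox_round(ctr, key, mul_a=0xD2511F53, mul_b=0xCD9E8D57):
    prod_a = mul_a * ctr[0]                   # 64-bit product, split into hi/lo 32-bit words
    prod_b = mul_b * ctr[2]
    hi_a, lo_a = prod_a >> 32, prod_a & 0xFFFFFFFF
    hi_b, lo_b = prod_b >> 32, prod_b & 0xFFFFFFFF
    return [hi_b ^ ctr[1] ^ key[0], lo_b, hi_a ^ ctr[3] ^ key[1], lo_a]

# Example: ten rounds over a counter block, bumping the key between rounds as Philox does.
ctr, key = [0, 0, 0, 0], [42, 7]
for _ in range(10):
    ctr = philox_round(ctr, key)
    key = [(key[0] + 0x9E3779B9) & 0xFFFFFFFF, (key[1] + 0xBB67AE85) & 0xFFFFFFFF]
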
# test uniform PRNG
|
||||
|
||||
|
||||
@pytest.mark.parametrize('size, seed',
|
||||
[(size, seed) for size in [1000000]
|
||||
for seed in [0, 42, 124, 54]]
|
||||
)
|
||||
@pytest.mark.parametrize('size, seed', [(size, seed) for size in [1000000] for seed in [0, 42, 124, 54]])
|
||||
def test_rand(size, seed, device):
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, N, seed):
|
||||
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
|
||||
rand = tl.rand(seed, offset)
|
||||
tl.store(X + offset, rand, mask=offset < N)
|
||||
|
||||
# triton result
|
||||
x = torch.empty(size, dtype=torch.float32, device=device)
|
||||
N = x.numel()
|
||||
grid = (triton.cdiv(N, BLOCK),)
|
||||
grid = (triton.cdiv(N, BLOCK), )
|
||||
kernel[grid](x, N, seed)
|
||||
assert all((x >= 0) & (x <= 1))
assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01


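# Illustrative sketch (not from the repo diff): the same Kolmogorov-Smirnov uniformity
# check on host-generated samples, to show what the 0.01 threshold above measures. For
# 10**6 genuinely uniform values the KS statistic is on the order of 1e-3.
import numpy as np
import scipy.stats

u = np.random.default_rng(0).random(1_000_000)
assert scipy.stats.kstest(u, 'uniform', args=(0, 1)).statistic < 0.01
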
# test normal PRNG
|
||||
|
||||
|
||||
@pytest.mark.parametrize('size, seed',
|
||||
[(size, seed) for size in [1000000]
|
||||
for seed in [0, 42, 124, 54]]
|
||||
)
|
||||
@pytest.mark.parametrize('size, seed', [(size, seed) for size in [1000000] for seed in [0, 42, 124, 54]])
|
||||
def test_randn(size, seed, device):
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, N, seed):
|
||||
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
|
||||
rand = tl.randn(seed, offset)
|
||||
tl.store(X + offset, rand, mask=offset < N)
|
||||
|
||||
# triton result
|
||||
x = torch.empty(size, dtype=torch.float32, device=device)
|
||||
N = x.numel()
|
||||
grid = (triton.cdiv(N, BLOCK),)
|
||||
grid = (triton.cdiv(N, BLOCK), )
|
||||
kernel[grid](x, N, seed)
|
||||
assert abs(x.mean()) < 1e-2
|
||||
assert abs(x.std() - 1) < 1e-2
|
||||
@@ -179,7 +182,9 @@ def test_randn(size, seed, device):
|
||||
|
||||
# tl.rand() should never produce >=1.0
|
||||
|
||||
|
||||
def test_rand_limits(device):
|
||||
|
||||
@triton.jit
|
||||
def kernel(input, output, n: tl.constexpr):
|
||||
idx = tl.arange(0, n)
|
||||
@@ -192,7 +197,7 @@ def test_rand_limits(device):
|
||||
torch.iinfo(torch.int32).max,
|
||||
], dtype=torch.int32, device=device)
|
||||
output = torch.empty(2, dtype=torch.float32, device=device)
|
||||
kernel[(1,)](min_max_int32, output, 2)
|
||||
kernel[(1, )](min_max_int32, output, 2)
|
||||
|
||||
assert output[0] == output[1]
|
||||
assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import itertools
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from collections import Counter
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -9,53 +11,88 @@ print_path = os.path.join(dir_path, "print_helper.py")
|
||||
assert_path = os.path.join(dir_path, "assert_helper.py")
|
||||
|
||||
# TODO: bfloat16 after LLVM-15
|
||||
assert_types = ["device_assert", "assert", "static_assert", "no_debug"]
|
||||
assert_types = ["device_assert", "device_assert_passes", "assert", "static_assert", "no_debug", "double_assert"]
|
||||
nested_types = [(caller, callee) for caller in ["true", "false", "none"] for callee in ["true", "false", "none"]]
|
||||
torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("func_type, data_type",
|
||||
[("device_print", data_type) for data_type in torch_types] + [("print", "int32"), ("static_print", "int32"), ("no_arg_print", "int32")])
|
||||
# TODO: Print with multiple operands
|
||||
@pytest.mark.parametrize("func_type, data_type", [("device_print", data_type) for data_type in torch_types] + [
|
||||
("print", "int32"),
|
||||
("static_print", "int32"),
|
||||
("no_arg_print", "int32"),
|
||||
("print_no_arg", "int32"),
|
||||
("device_print_large", "int32"),
|
||||
("print_multiple_args", "int32"),
|
||||
("device_print_multiple_args", "int32"),
|
||||
])
|
||||
def test_print(func_type: str, data_type: str):
|
||||
proc = subprocess.Popen([sys.executable, print_path, func_type, data_type], stdout=subprocess.PIPE, shell=False)
|
||||
outs, _ = proc.communicate()
|
||||
outs = outs.split()
|
||||
new_lines = set()
|
||||
for line in outs:
|
||||
try:
|
||||
value = line
|
||||
if func_type != "static_print":
|
||||
value = int(float(line))
|
||||
new_lines.add(value)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
if func_type != "static_print" and func_type != "no_arg_print":
|
||||
outs = [line for line in outs.decode("UTF-8").split("\n") if line]
|
||||
|
||||
# Format is
|
||||
# pid (<x>, <y>, <z>) idx (<i1>, <i2>, ...) <prefix> (operand <n>) <elem>
|
||||
expected_lines = Counter()
|
||||
if func_type == "print" or func_type == "device_print":
|
||||
for i in range(128):
|
||||
assert i in new_lines
|
||||
else:
|
||||
assert len(new_lines) == 1
|
||||
line = f"pid (0, 0, 0) idx ({i:3}) x: {i}"
|
||||
if data_type.startswith("float"):
|
||||
line += ".000000"
|
||||
expected_lines[line] = 1
|
||||
elif func_type == "static_print":
|
||||
expected_lines[" int32[constexpr[128]]"] = 1
|
||||
elif func_type == "no_arg_print":
|
||||
expected_lines["pid (0, 0, 0) idx (): 0"] = 128
|
||||
elif func_type == "print_no_arg":
|
||||
expected_lines["pid (0, 0, 0) no arg"] = 128
|
||||
elif func_type == "device_print_large":
|
||||
for i, j, k in itertools.product(range(2), range(64), range(128)):
|
||||
expected_lines[f"pid (0, {i}, 0) idx ({j:2}, {k:3}) x: 1"] = 1
|
||||
elif func_type == "print_multiple_args" or func_type == "device_print_multiple_args":
|
||||
for i in range(128):
|
||||
expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 0) {i}"] = 1
|
||||
expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 1) 1"] = 1
|
||||
|
||||
actual_lines = Counter()
|
||||
for line in outs:
|
||||
actual_lines[line] += 1
|
||||
|
||||
diff = Counter(actual_lines)
|
||||
diff.subtract(expected_lines)
|
||||
for line, delta in diff.items():
|
||||
if delta == 0:
|
||||
continue
|
||||
print(f'Expected line "{line}" {expected_lines[line]} time(s), but saw {actual_lines[line]} time(s)')
assert all(delta == 0 for delta in diff.values())


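# Illustrative sketch (not from the repo diff): the Counter-based multiset comparison used
# by test_print, in isolation. Expected and actual output lines must match including their
# multiplicities, and the subtract/diff idiom reports exactly which lines are off. The
# example strings below are illustrative only.
from collections import Counter

expected = Counter({"pid (0, 0, 0) idx (  0) x: 0": 1, "pid (0, 0, 0) idx (  1) x: 1": 1})
actual = Counter(["pid (0, 0, 0) idx (  0) x: 0", "pid (0, 0, 0) idx (  1) x: 1"])
diff = Counter(actual)
diff.subtract(expected)
assert all(delta == 0 for delta in diff.values())
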
@pytest.mark.parametrize("func_type", assert_types)
|
||||
def test_assert(func_type: str):
|
||||
os.environ["TRITON_DEBUG"] = "1"
|
||||
proc = subprocess.Popen([sys.executable, assert_path, func_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
|
||||
proc = subprocess.Popen([sys.executable, assert_path, func_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE,
|
||||
shell=False)
|
||||
_, errs = proc.communicate()
|
||||
errs = errs.splitlines()
|
||||
num_errs = 0
|
||||
for err in errs:
|
||||
if "x != 0" in err.decode("utf-8"):
|
||||
num_errs += 1
|
||||
|
||||
# Check for segfaults.
|
||||
assert all("segmentation fault" not in line.decode("utf-8").lower() for line in errs)
|
||||
|
||||
os.environ["TRITON_DEBUG"] = "0"
|
||||
if func_type != "static_assert":
|
||||
assert num_errs == 127
|
||||
else:
|
||||
if func_type == "static_assert" or func_type == "device_assert_passes":
|
||||
assert num_errs == 0
|
||||
else:
|
||||
assert num_errs == 127
|
||||
|
||||
|
||||
@pytest.mark.parametrize("caller_type, callee_type", nested_types)
|
||||
def test_assert_nested(caller_type, callee_type):
|
||||
proc = subprocess.Popen([sys.executable, assert_path, caller_type, callee_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
|
||||
proc = subprocess.Popen([sys.executable, assert_path, caller_type, callee_type], stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE, shell=False)
|
||||
_, errs = proc.communicate()
|
||||
errs = errs.splitlines()
|
||||
num_errs = 0
|
||||
|
||||
@@ -68,8 +68,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
|
||||
b_ref = do_mask(b_ref) if is_dds else b_ref
|
||||
a_ref.retain_grad()
|
||||
b_ref.retain_grad()
|
||||
c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
|
||||
b_ref.transpose(2, 3) if TRANS_B else b_ref)
|
||||
c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref, b_ref.transpose(2, 3) if TRANS_B else b_ref)
|
||||
c_ref.backward(dc_ref)
|
||||
c_ref = do_sparsify(c_ref) if is_sdd else c_ref
|
||||
da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad
|
||||
@@ -172,7 +171,7 @@ def test_attention_fwd_bwd(
|
||||
value.retain_grad()
|
||||
attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale)
|
||||
# ad hoc loss
|
||||
loss = (attn_out ** 2).mean()
|
||||
loss = (attn_out**2).mean()
|
||||
loss.backward()
|
||||
grads = [query.grad, key.grad, value.grad]
|
||||
|
||||
@@ -189,7 +188,7 @@ def test_attention_fwd_bwd(
|
||||
probs = torch.softmax(scores, dim=-1)
|
||||
torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
|
||||
# ad hoc loss
|
||||
torch_loss = (torch_attn_out ** 2).mean()
|
||||
torch_loss = (torch_attn_out**2).mean()
|
||||
torch_loss.backward()
|
||||
torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
|
||||
|
||||
@@ -209,8 +208,10 @@ def triton_attention(
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
):
|
||||
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device)
|
||||
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device)
|
||||
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True,
|
||||
device=value.device)
|
||||
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False,
|
||||
device=value.device)
|
||||
sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device)

w = sparse_dot_sdd_nt(query, key)

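# Illustrative sketch (not from the repo diff): the dense computation that the
# sdd -> softmax -> dsd blocksparse pipeline above corresponds to, ignoring the sparsity
# layout. Plain torch; the scale placement follows the reference path in
# test_attention_fwd_bwd and is otherwise an assumption.
import torch

def dense_attention(q, k, v, scale):
    w = torch.matmul(q, k.transpose(-2, -1)) * scale     # "sdd": dense attention scores
    p = torch.softmax(w.float(), dim=-1).to(v.dtype)     # blocksparse softmax analogue
    return torch.matmul(p, v)                            # "dsd": weighted values
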
@@ -5,14 +5,13 @@ import triton
|
||||
import triton.ops
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N, dtype, mode",
|
||||
[
|
||||
(M, N, dtype, mode) for M in [1024, 821]
|
||||
for N in [512, 857, 1871, 2089, 8573, 31000]
|
||||
for dtype in ['float16', 'float32']
|
||||
for mode in ['forward', 'backward']
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize("M, N, dtype, mode", [ #
|
||||
(M, N, dtype, mode)
|
||||
for M in [1024, 821]
|
||||
for N in [512, 857, 1871, 2089, 8573, 31000]
|
||||
for dtype in ['float16', 'float32']
|
||||
for mode in ['forward', 'backward']
|
||||
])
|
||||
def test_op(M, N, dtype, mode):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 8 and dtype == "bfloat16":
|
||||
|
||||
@@ -5,10 +5,12 @@ import triton
|
||||
import triton.ops
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 4, 512, 16),
|
||||
(2, 4, 512, 32),
|
||||
(2, 4, 512, 64),
|
||||
(2, 4, 512, 128)])
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ #
|
||||
(2, 4, 512, 16),
|
||||
(2, 4, 512, 32),
|
||||
(2, 4, 512, 64),
|
||||
(2, 4, 512, 128),
|
||||
])
|
||||
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize('causal', [True, False])
|
||||
@pytest.mark.parametrize('seq_par', [True, False])
|
||||
@@ -56,6 +58,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
|
||||
|
||||
# # triton implementation
|
||||
tri_out = triton.ops.attention(q, k, v, causal, sm_scale, seq_par)
|
||||
<<<<<<< HEAD
|
||||
# print(ref_out)
|
||||
# print(tri_out)
|
||||
if torch.version.hip is None:
|
||||
@@ -70,3 +73,74 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
|
||||
torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
|
||||
torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
|
||||
torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)
|
||||
=======
|
||||
tri_out.backward(dout)
|
||||
tri_dv, v.grad = v.grad.clone(), None
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
tri_dq, q.grad = q.grad.clone(), None
|
||||
# compare
|
||||
atol = 1e-1 if dtype == torch.bfloat16 else 1e-2
|
||||
torch.testing.assert_close(ref_out, tri_out, atol=atol, rtol=0)
|
||||
torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
|
||||
torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
|
||||
torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)
|
||||
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_func
|
||||
HAS_FLASH = True
|
||||
except BaseException:
|
||||
HAS_FLASH = False
|
||||
|
||||
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
|
||||
# vary seq length for fixed head and batch=4
|
||||
configs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=['N_CTX'], x_vals=[2**i for i in range(10, 14)], line_arg='provider',
|
||||
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
|
||||
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), styles=[('red', '-'), ('blue', '-')], ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{casual}-{seq_par}', args={
|
||||
'H': N_HEADS,
|
||||
'BATCH': BATCH,
|
||||
'D_HEAD': D_HEAD,
|
||||
'dtype': torch.float16,
|
||||
'mode': mode,
|
||||
'casual': casual,
|
||||
'seq_par': seq_par,
|
||||
}) for mode in ['fwd', 'bwd'] for casual in [True, False] for seq_par in [True, False]
|
||||
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, casual, seq_par, provider, dtype=torch.float16, device="cuda"):
|
||||
assert mode in ['fwd', 'bwd']
|
||||
warmup = 25
|
||||
rep = 100
|
||||
sm_scale = 1.3
|
||||
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
if provider == "triton":
|
||||
fn = lambda: triton.ops.attention(q, k, v, casual, sm_scale, seq_par)
|
||||
if mode == 'bwd':
|
||||
o = fn()
|
||||
do = torch.randn_like(o)
|
||||
fn = lambda: o.backward(do, retain_graph=True)
|
||||
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
if provider == "flash":
|
||||
lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device)
|
||||
cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32)
|
||||
cu_seqlens[1:] = lengths.cumsum(0)
|
||||
fn = lambda: flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=sm_scale, causal=casual)
|
||||
if mode == 'bwd':
|
||||
o = fn()
|
||||
do = torch.randn_like(o)
|
||||
fn = lambda: o.backward(do, retain_graph=True)
|
||||
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
|
||||
|
||||
# only works on post-Ampere GPUs right now
# bench_flash_attention.run(save_path='.', print_data=True)
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52

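# Illustrative sketch (not from the repo diff): minimal triton.testing.perf_report usage
# with the same Benchmark fields as the bench_flash_attention config list above. The
# sizes, provider name and plot name are made up for illustration; running it needs a
# CUDA device.
import torch
import triton
import triton.testing

sketch_cfg = triton.testing.Benchmark(
    x_names=['N'], x_vals=[2**i for i in range(12, 16)], line_arg='provider',
    line_vals=['torch'], line_names=['Torch'], styles=[('green', '-')], ylabel='ms',
    plot_name='editor-sketch-elementwise-add', args={})

@triton.testing.perf_report([sketch_cfg])
def bench_sketch_add(N, provider):
    x = torch.randn(N, device='cuda')
    return triton.testing.do_bench(lambda: x + x)

# bench_sketch_add.run(print_data=True)
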
@@ -8,7 +8,8 @@ import triton.language as tl
|
||||
def test_normalization_with_remat():
|
||||
|
||||
@triton.jit
|
||||
def triton_(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
|
||||
def triton_(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel, rnumel, XBLOCK: tl.constexpr,
|
||||
RBLOCK: tl.constexpr):
|
||||
xnumel = 512
|
||||
rnumel = 4096
|
||||
xoffset = tl.program_id(0) * XBLOCK
|
||||
@@ -52,7 +53,7 @@ def test_normalization_with_remat():
|
||||
arg115_1 = torch.rand(64, device="cuda")
|
||||
arg8_1 = torch.rand(64, device="cuda")
|
||||
arg9_1 = torch.rand(64, device="cuda")
|
||||
triton_[(512,)](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048)
|
||||
triton_[(512, )](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048)
|
||||
torch.testing.assert_close(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0)
|
||||
|
||||
|
||||
@@ -148,7 +149,7 @@ def test_avg_pool_bw():
|
||||
inp = torch.ones(8, 2048, 8, 8, device="cuda", dtype=torch.half)
|
||||
out = torch.ones_like(inp) * 3
|
||||
numel = inp.numel()
|
||||
triton_[(numel // 1024,)](inp, out, 1024)
|
||||
triton_[(numel // 1024, )](inp, out, 1024)
|
||||
out_ref = torch.ones_like(inp)
|
||||
out_ref[:, :, 1:7, 0::7] = 2 / 3
|
||||
out_ref[:, :, 0::7, 1:7] = 2 / 3
|
||||
@@ -159,6 +160,7 @@ def test_avg_pool_bw():
|
||||
@pytest.mark.parametrize("RBLOCK", [1, 16, 32, 64, 128])
|
||||
@pytest.mark.parametrize("num_warps", [1, 4])
|
||||
def test_scan2d_broadcast(RBLOCK, num_warps):
|
||||
|
||||
@triton.jit(debug=True)
|
||||
def fn(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
|
||||
rindex = tl.arange(0, RBLOCK)[None, :]
|
||||
@@ -172,12 +174,13 @@ def test_scan2d_broadcast(RBLOCK, num_warps):
|
||||
XBLOCK = 4
|
||||
input = torch.randint(0, 10, (1, RBLOCK), dtype=torch.int64, device='cuda')
|
||||
output = torch.empty((XBLOCK, RBLOCK), dtype=torch.int64, device='cuda')
|
||||
fn[(1,)](input, output, XBLOCK, RBLOCK, num_warps=num_warps)
|
||||
fn[(1, )](input, output, XBLOCK, RBLOCK, num_warps=num_warps)
|
||||
ref = input.cumsum(1).broadcast_to((XBLOCK, RBLOCK))
|
||||
torch.testing.assert_close(output, ref)
|
||||
|
||||
|
||||
def test_scan2d_for():
|
||||
|
||||
@triton.jit
|
||||
def fn(out_ptr0, rnumel, RBLOCK: tl.constexpr):
|
||||
rbase = tl.arange(0, RBLOCK)[None, :]
|
||||
@@ -190,6 +193,6 @@ def test_scan2d_for():
|
||||
|
||||
RBLOCK = 8
|
||||
out0 = torch.empty(RBLOCK, device="cuda", dtype=torch.int64)
|
||||
fn[(1,)](out0, RBLOCK, RBLOCK)
|
||||
fn[(1, )](out0, RBLOCK, RBLOCK)
|
||||
ref = torch.arange(RBLOCK, device="cuda", dtype=torch.int64) + 1
|
||||
torch.testing.assert_close(out0, ref)
|
||||
|
||||
@@ -19,7 +19,7 @@ def f8_to_f16(x, dtype):
|
||||
tl.store(Y + offs, x, mask=mask)
|
||||
|
||||
ret = torch.empty(x.shape, dtype=torch.float16, device=x.device)
grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),)
grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']), )
dtype = getattr(tl, dtype)
kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024)
return ret
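
# Illustrative sketch (not from the repo diff): the ceil-division behind the grid lambdas
# in this file. One program is launched per BLOCK_SIZE chunk, rounding up so tail elements
# are still covered (and masked inside the kernel). Assumes triton.cdiv(a, b) == (a + b - 1) // b.
def cdiv(a, b):
    return (a + b - 1) // b

assert cdiv(1024, 1024) == 1
assert cdiv(1025, 1024) == 2  # one extra program for the single tail element
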
@@ -28,87 +28,88 @@ def f8_to_f16(x, dtype):
|
||||
@pytest.mark.parametrize(
|
||||
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM",
|
||||
itertools.chain(
|
||||
*[
|
||||
[
|
||||
# 1 warp
|
||||
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
# 2 warp
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
# 4 warp
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
# 8 warp
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
# variable input
(128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True, True),
(128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True, True),
(128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True, True),
(128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True, True),
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
],
# n-stage
*[[
(16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True, True),
(64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True, True),
(128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True, True),
(256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True, True),
(128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True, True),
]
for DTYPE in ["float16", "bfloat16", "float32"]
for AT in [False, True]
for BT in [False, True]
for STAGES in [4]],
# mixed-precision
*[[
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
] for ADTYPE, BDTYPE in [
("float8e4nv", "float8e5"),
("float8e4nv", "float8e4nv"),
("float8e5", "float8e4nv"),
("float8e5", "float8e5"),
("float8e4b15", "float8e4b15"),
("float8e4nv", "float16"),
("float16", "float8e5"),
("float16", "float32"),
("float32", "float16"),
("bfloat16", "float32"),
("float32", "bfloat16"),
] for AT in [False, True] for BT in [False, True] for FASTACCUM in [True, False]],
# mixed-precision block layout
*[[
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True),
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True),
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False, True),
] for ADTYPE, BDTYPE in [
("float8e4nv", "float16"),
("float16", "float8e5"),
("float16", "float32"),
("float32", "float16"),
("bfloat16", "float32"),
("float32", "bfloat16"),
] for AT in [False, True] for BT in [False, True]],
),
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM):
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32,
F8_FASTACCUM):
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
@@ -147,7 +148,7 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
return torch.randint(-128, 127, (m, n), device="cuda", dtype=torch.int8)
dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[dtype]
exponents = torch.randint(-10, 0, size=(m, n))
ret = (2. ** exponents).to(dtype).to("cuda")
ret = (2.**exponents).to(dtype).to("cuda")
return ret

# allocate/transpose inputs

@@ -17,6 +17,25 @@ def test_kwargs():
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(src + offsets, mask=offsets < N)
tl.store(dst + offsets, x, mask=offsets < N)
grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']),)

grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )
_kernel[grid](dst, src, N)
_kernel[grid](dst=dst, src=src, N=N)


def test_restore():
N = 1024
src = torch.zeros(N, device='cuda')

configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]

@triton.autotune(configs=configs, key=['N'], restore_value=['src'])
@triton.jit
def _kernel(src, N, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(src + offsets, mask=offsets < N) + 1
tl.store(src + offsets, x, mask=offsets < N)

grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )
_kernel[grid](src, N)
triton.testing.assert_close(src, torch.ones_like(src))

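For context, a minimal usage sketch of the autotuner's restore_value option exercised by test_restore above; the kernel name (_inplace_inc) and buffer name (buf) are illustrative assumptions, not part of this diff.

import torch
import triton
import triton.language as tl

configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]


# The autotuner re-runs the kernel once per candidate config while benchmarking.
# Because this kernel updates `buf` in place, restore_value=['buf'] asks the
# autotuner to snapshot and restore the buffer between trial runs, so the final
# contents reflect exactly one logical launch.
@triton.autotune(configs=configs, key=['N'], restore_value=['buf'])
@triton.jit
def _inplace_inc(buf, N, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(buf + offsets, mask=offsets < N) + 1
    tl.store(buf + offsets, x, mask=offsets < N)


N = 1024
buf = torch.zeros(N, device='cuda')
grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )
_inplace_inc[grid](buf, N)
assert torch.equal(buf, torch.ones_like(buf))
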
@@ -80,11 +80,12 @@ def test_reuse():
def inc_counter(*args, **kwargs):
nonlocal counter
counter += 1

JITFunction.cache_hook = inc_counter
reset_tmp_dir()
x = torch.empty(1, dtype=torch.int32, device='cuda')
for i in range(10):
kernel[(1,)](x, 1, BLOCK=1024)
kernel[(1, )](x, 1, BLOCK=1024)
assert counter == 1


@@ -95,17 +96,19 @@ def test_specialize(mode):
def inc_counter(*args, **kwargs):
nonlocal counter
counter += 1

JITFunction.cache_hook = inc_counter
reset_tmp_dir()
x = torch.empty(1, dtype=torch.int32, device='cuda')
function = {'enable': kernel, 'disable': kernel_nospec}[mode]
target = {'enable': 4, 'disable': 1}[mode]
for i in [1, 2, 4, 8, 16, 32]:
function[(1,)](x, i, BLOCK=512)
function[(1, )](x, i, BLOCK=512)
assert counter == target


def test_annotation():

@triton.jit
def kernel(X, i: tl.int32):
tl.store(X, i)
@@ -113,14 +116,15 @@ def test_annotation():
x = torch.empty(1, dtype=torch.int32, device='cuda')

device = torch.cuda.current_device()
kernel[(1,)](x, 1)
kernel[(1,)](x, 8)
kernel[(1,)](x, 16)
kernel[(1,)](x, 17)
kernel[(1, )](x, 1)
kernel[(1, )](x, 8)
kernel[(1, )](x, 16)
kernel[(1, )](x, 17)
assert len(kernel.cache[device]) == 4

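As a standalone illustration of the cache_hook pattern used by test_reuse and test_specialize above; the import path and the _store_val / compile_count names are assumptions for this sketch, not part of the diff.

import torch
import triton
import triton.language as tl
from triton.runtime.jit import JITFunction

compile_count = 0


def count_compiles(*args, **kwargs):
    # The hook fires once per fresh compilation; cache hits do not trigger it.
    global compile_count
    compile_count += 1


JITFunction.cache_hook = count_compiles


@triton.jit
def _store_val(X, i, BLOCK: tl.constexpr):
    tl.store(X, i)


x = torch.empty(1, dtype=torch.int32, device='cuda')
for _ in range(10):
    # Identical launches reuse the cached binary, so only the first one compiles.
    _store_val[(1, )](x, 1, BLOCK=1024)
assert compile_count == 1
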
def test_constexpr_not_callable() -> None:

@triton.jit
def kernel(X, c: tl.constexpr):
tl.store(X, 2)
@@ -141,11 +145,11 @@ def test_constexpr_not_callable() -> None:


def test_jit_warmup_cache() -> None:

@triton.jit
def kernel_add(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.store(o + idx,
tl.load(a + idx) + tl.load(b + idx))
tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))

args = [
torch.randn(32, dtype=torch.float32, device="cuda"),
@@ -155,31 +159,31 @@ def test_jit_warmup_cache() -> None:
]
device = torch.cuda.current_device()
assert len(kernel_add.cache[device]) == 0
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add.cache[device]) == 1
kernel_add.warmup(*args, grid=(1,))
kernel_add.warmup(*args, grid=(1, ))
assert len(kernel_add.cache[device]) == 1
kernel_add.warmup(*args, grid=(1,))
kernel_add.warmup(*args, grid=(1, ))
assert len(kernel_add.cache[device]) == 1


def test_jit_debug() -> None:

@triton.jit
def kernel_add(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.device_assert(idx < 32, "idx < 32")
tl.store(o + idx,
tl.load(a + idx) + tl.load(b + idx))
tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))

device = torch.cuda.current_device()
assert len(kernel_add.cache[device]) == 0
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add.cache[device]) == 1
kernel_add.debug = False
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add.cache[device]) == 2
kernel_add.debug = True
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add.cache[device]) == 3
bins = list(kernel_add.cache[device].values())
assert bins[2].asm['ttir'] != bins[1].asm['ttir']
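For reference, a small annotated sketch of the warmup pre-compilation pattern these tests rely on; the _axpy kernel below is an illustrative assumption.

import torch
import triton
import triton.language as tl


@triton.jit
def _axpy(x, y, o, N: tl.constexpr):
    idx = tl.arange(0, N)
    tl.store(o + idx, 2 * tl.load(x + idx) + tl.load(y + idx))


# warmup() compiles for the given argument types / constexpr values without
# launching anything, so the binary is already in _axpy.cache before the first
# real call. Passing dtypes instead of tensors is enough because only the
# signature matters for compilation.
_axpy.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
device = torch.cuda.current_device()
assert len(_axpy.cache[device]) == 1
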
@@ -192,13 +196,14 @@ def add_fn(a, b, o, N: tl.constexpr):


def test_jit_noinline() -> None:

@triton.jit
def kernel_add_device(a, b, o, N: tl.constexpr):
add_fn(a, b, o, N)

device = torch.cuda.current_device()
assert len(kernel_add_device.cache[device]) == 0
kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add_device.cache[device]) == 1
bins = list(kernel_add_device.cache[device].values())
inline_ttir = bins[0].asm['ttir']
@@ -206,7 +211,7 @@ def test_jit_noinline() -> None:
add_fn.hash = None
kernel_add_device.hash = None
kernel_add_device.cache[device].clear()
kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add_device.cache[device]) == 1
bins = list(kernel_add_device.cache[device].values())
noinline_ttir = bins[0].asm['ttir']
@@ -214,6 +219,7 @@ def test_jit_noinline() -> None:


def test_memory_leak() -> None:

@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
xnumel = 10

@@ -31,11 +31,11 @@ def test_memory_leak() -> None:
try:
inp = torch.randn(10, device='cuda')
out = torch.randn(10, device='cuda')
kernel[(10,)](inp, out, 10, XBLOCK=16)
kernel[(10, )](inp, out, 10, XBLOCK=16)
gc.collect()
begin, _ = tracemalloc.get_traced_memory()
for _ in range(100):
kernel[(10,)](inp, out, 10, XBLOCK=16)
kernel[(10, )](inp, out, 10, XBLOCK=16)
gc.collect()
end, _ = tracemalloc.get_traced_memory()
assert end - begin < 30000

@@ -17,9 +17,11 @@ def reset_tmp_dir():
shutil.rmtree(tmpdir, ignore_errors=True)


instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])
instance_descriptor = namedtuple("instance_descriptor",
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])


<<<<<<< HEAD
def get_device_type():
try:
import torch
@@ -36,10 +38,15 @@ def get_device_type():


def compile_fn(config, device_type, cc):
=======
def compile_fn(config, cc):

>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
@triton.jit
def kernel_sub(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777)

triton.compile(
fn=kernel_sub,
signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
@@ -57,15 +64,24 @@ def test_compile_in_subproc() -> None:
config = instance_descriptor(tuple(range(4)), (), (), ())

multiprocessing.set_start_method('fork')
<<<<<<< HEAD
proc = multiprocessing.Process(
target=compile_fn,
args=(config, device_type, cc))
=======
proc = multiprocessing.Process(target=compile_fn, args=(config, cc))
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
proc.start()
proc.join()
assert proc.exitcode == 0


<<<<<<< HEAD
def compile_fn_dot(config, device_type, cc):
=======
def compile_fn_dot(config, cc):

>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
@triton.jit
def kernel_dot(Z):
offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :]
@@ -90,9 +106,13 @@ def test_compile_in_forked_subproc() -> None:
config = instance_descriptor(tuple(range(1)), (), (), ())

assert multiprocessing.get_start_method() == 'fork'
<<<<<<< HEAD
proc = multiprocessing.Process(
target=compile_fn_dot,
args=(config, device_type, cc))
=======
proc = multiprocessing.Process(target=compile_fn_dot, args=(config, cc))
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
proc.start()
proc.join()
assert proc.exitcode == 0

@@ -59,7 +59,7 @@ def kernel(C, A, B, M, N, K,
tl.store(c_ptrs, c)
"""

test_utils_src = '''
test_utils_src = """
#include <cuda.h>
#include <stdio.h>
#include <stdint.h>
@@ -93,23 +93,26 @@ static void read_csv_to_buffer(char *filename, int16_t *buffer, int size) {
index++;
}
fclose(file);
}'''
}"""


def gen_kernel_library(dir, libname):
c_files = glob.glob(os.path.join(dir, "*.c"))
subprocess.run(["gcc"] + c_files + ["-I", cuda_include_dir(),
"-c", "-fPIC"],
check=True, cwd=dir)
subprocess.run(
["gcc"] + c_files + ["-I", cuda_include_dir(), "-c", "-fPIC"],
check=True,
cwd=dir,
)
o_files = glob.glob(os.path.join(dir, "*.o"))
subprocess.run(["gcc"] + o_files + ["-shared",
"-o", libname,
"-L", libcuda_dirs()[0]],
check=True, cwd=dir)
subprocess.run(
["gcc"] + o_files + ["-shared", "-o", libname, "-L", libcuda_dirs()[0]],
check=True,
cwd=dir,
)


def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
test_src = f'''
test_src = f"""
int main(int argc, char **argv) {{
int M = {M}, N = {N}, K = {K};

@@ -165,17 +168,29 @@ int main(int argc, char **argv) {{
cuMemFree(C);
cuCtxDestroy(ctx);
}}
'''
"""
src = test_utils_src + test_src
with open(os.path.join(dir, "test.c"), "w") as file:
file.write(src)
subprocess.run(["gcc"] + ["test.c",
"-I", cuda_include_dir(),
"-L", libcuda_dirs()[0],
"-l", "cuda",
"-L", dir,
"-l", "kernel",
"-o", exe], check=True, cwd=dir)
subprocess.run(
["gcc"] + [
"test.c",
"-I",
cuda_include_dir(),
"-L",
libcuda_dirs()[0],
"-l",
"cuda",
"-L",
dir,
"-l",
"kernel",
"-o",
exe,
],
check=True,
cwd=dir,
)

def write_triton_kernels(dir, src, util_src):
@@ -190,16 +205,67 @@ def write_triton_kernels(dir, src, util_src):
return kernel_path


def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints):
def _compile_kernel(dir, signature, kernel_name, out_name, out_path, num_warps, grid, kernel_path):
compiler_path = os.path.join(triton.tools.__path__[0], "compile.py")

subprocess.run(
[
sys.executable,
compiler_path,
"-n",
kernel_name,
"--signature",
signature,
"--out-name",
out_name,
"-o",
out_path,
"-w",
str(num_warps),
"-g",
grid,
kernel_path,
],
check=True,
cwd=dir,
)


# Edge case kernel with no specialization
def compile_aot_kernel_no_specialization(dir, kernel_path, dtype, BM, BN, BK):
# compile all desired configs
sig = f"*fp32, *{dtype}, *{dtype}, i32, i32, i32, i32, i32, i32, i32, i32, i32, {BM}, {BN}, {BK}"
name = f"matmul_{dtype}"
grid = f"M/{BM}, N/{BN}, 1"
_compile_kernel(
dir=dir,
signature=sig,
kernel_name="kernel",
out_name=name,
out_path=name,
num_warps=1,
grid=grid,
kernel_path=kernel_path,
)


def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints):
# compile all desired configs
for ha in ha_hb_hints:
for hb in ha_hb_hints:
sig = f'*fp32:16, *{dtype}:16, *{dtype}:16, i32, i32, i32, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}'
sig = f"*fp32:16, *{dtype}:16, *{dtype}:16, i32, i32, i32, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}"
name = f"matmul_{dtype}"
grid = f'M/{BM}, N/{BN}, 1'
subprocess.run([sys.executable, compiler_path, "-n", "kernel", "--signature", sig, "--out-name", name, "-o", name, "-w", "1", "-g", grid, kernel_path], check=True, cwd=dir)
grid = f"M/{BM}, N/{BN}, 1"
_compile_kernel(
dir=dir,
signature=sig,
kernel_name="kernel",
out_name=name,
out_path=name,
num_warps=1,
grid=grid,
kernel_path=kernel_path,
)

def link_aot_kernels(dir):
@@ -221,11 +287,42 @@ def generate_matmul_test_data(dir, M, N, K):
return a, b, a_path, b_path, c_path


# Test edge case where the provided kernel signature has no specializations
def test_compile_link_matmul_no_specialization():
np.random.seed(3)

with tempfile.TemporaryDirectory() as tmp_dir:
dtype = "fp16"
BM, BN, BK = 16, 16, 16

kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
compile_aot_kernel_no_specialization(tmp_dir, kernel_path, dtype, BM, BN, BK)
link_aot_kernels(tmp_dir)

# compile test case
M, N, K = 16, 16, 16
gen_kernel_library(tmp_dir, "libkernel.so")
gen_test_bin(tmp_dir, M, N, K)

# initialize test data
a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K)

# run test case
env = os.environ.copy()
env["LD_LIBRARY_PATH"] = tmp_dir
subprocess.run(["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir)

# read data and compare against reference
c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
c_tri = c.reshape((M, N)).view(np.float32)
c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32))
np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.0)


def test_compile_link_matmul():
np.random.seed(3)

with tempfile.TemporaryDirectory() as tmp_dir:

dtype = "fp16"
BM, BN, BK = 16, 16, 16

@@ -250,7 +347,7 @@ def test_compile_link_matmul():
c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
c_tri = c.reshape((M, N)).view(np.float32)
c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32))
np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.)
np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.0)


def test_launcher_has_no_available_kernel():
@@ -275,7 +372,13 @@ def test_launcher_has_no_available_kernel():
# run test case
env = os.environ.copy()
env["LD_LIBRARY_PATH"] = tmp_dir
result = subprocess.run(["./test", a_path, b_path, c_path], env=env, cwd=tmp_dir, capture_output=True, text=True)
result = subprocess.run(
["./test", a_path, b_path, c_path],
env=env,
cwd=tmp_dir,
capture_output=True,
text=True,
)

# It should fail since the launcher requires all the strides be 1 while they are not.
assert result.returncode == -6
@@ -286,7 +389,6 @@ def test_compile_link_autotune_matmul():
np.random.seed(3)

with tempfile.TemporaryDirectory() as tmp_dir:

dtype = "fp16"

kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
@@ -319,7 +421,12 @@ def test_compile_link_autotune_matmul():

env = os.environ.copy()
env["LD_LIBRARY_PATH"] = tmp_dir
subprocess.run([f"./{test_name}", a_path, b_path, c_path], check=True, cwd=tmp_dir, env=env)
subprocess.run(
[f"./{test_name}", a_path, b_path, c_path],
check=True,
cwd=tmp_dir,
env=env,
)

# read data and compare against reference
c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)

@@ -45,12 +45,12 @@ __all__ = [
"tools",
]


# -------------------------------------
# misc. utilities that don't fit well
# into any specific module
# -------------------------------------


def cdiv(x: int, y: int):
return (x + y - 1) // y

@@ -1,5 +1,5 @@

import functools
import hashlib
import importlib
import importlib.util
import os
@@ -10,8 +10,12 @@ from typing import Dict

from ..runtime.driver import DriverBase

TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TRITON_VERSION = "2.1.0"


class BaseBackend:

def __init__(self, device_type: str) -> None:
self.device_type = device_type

@@ -104,7 +108,7 @@ def get_backend(device_type: str):
def _path_to_binary(binary: str):
base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
paths = [
os.environ.get("TRITON_PTXAS_PATH", ""),
os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
os.path.join(base_dir, "third_party", "cuda", "bin", binary)
]

@@ -132,3 +136,48 @@ def path_to_cuobjdump():
@functools.lru_cache()
def path_to_nvdisasm():
return _path_to_binary("nvdisasm")


@functools.lru_cache()
def compute_core_version_key():
import pkgutil
contents = []
# frontend
with open(__file__, "rb") as f:
contents += [hashlib.sha1(f.read()).hexdigest()]
# compiler
compiler_path = os.path.join(TRITON_PATH, 'compiler')
for lib in pkgutil.iter_modules([compiler_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.sha1(f.read()).hexdigest()]
# backend
libtriton_hash = hashlib.sha1()
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
while True:
chunk = f.read(1024**2)
if not chunk:
break
libtriton_hash.update(chunk)
contents.append(libtriton_hash.hexdigest())
# language
language_path = os.path.join(TRITON_PATH, 'language')
for lib in pkgutil.iter_modules([language_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.sha1(f.read()).hexdigest()]
return '-'.join(TRITON_VERSION) + '-'.join(contents)


_cached_cuda_version_key = None


def get_cuda_version_key():
global _cached_cuda_version_key
if _cached_cuda_version_key is None:
key = compute_core_version_key()
try:
ptxas = path_to_ptxas()[0]
ptxas_version = subprocess.check_output([ptxas, "--version"])
except RuntimeError:
ptxas_version = b"NO_PTXAS"
_cached_cuda_version_key = key + '-' + hashlib.sha1(ptxas_version).hexdigest()
return _cached_cuda_version_key

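A minimal sketch of how a version key like get_cuda_version_key() above can gate an on-disk kernel cache; the directory layout and helper name below are assumptions for illustration, not Triton's actual cache structure.

import os


def cache_dir_for(fn_hash: str, version_key: str) -> str:
    # Folding the version key into the path means a new ptxas or a rebuilt
    # libtriton.so automatically misses old entries instead of reusing stale
    # binaries; no explicit invalidation pass is needed.
    root = os.path.expanduser("~/.triton/cache")
    return os.path.join(root, f"{version_key}-{fn_hash}")
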
@@ -92,9 +92,15 @@ def _build(name, src, srcdir):
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]

if is_hip():
ret = subprocess.check_call([cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{hip_lib_dir}", "-lamdhip64", "-o", so])
ret = subprocess.check_call([
cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC",
f"-L{hip_lib_dir}", "-lamdhip64", "-o", so
])
else:
cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
cc_cmd = [
cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda",
"-o", so
]
cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
ret = subprocess.check_call(cc_cmd)

@@ -1,5 +1,8 @@
from .compiler import (CompiledKernel, compile, get_arch_default_num_stages,
get_arch_default_num_warps, instance_descriptor)
from .compiler import (CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps,
instance_descriptor)
from .errors import CompilationError

__all__ = ["compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", "get_arch_default_num_stages"]
__all__ = [
"compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps",
"get_arch_default_num_stages"
]

@@ -10,8 +10,7 @@ from .._C.libtriton.triton import ir
|
||||
from ..language import constexpr, tensor
|
||||
# ideally we wouldn't need any runtime component
|
||||
from ..runtime import JITFunction
|
||||
from .errors import (CompilationError, CompileTimeAssertionFailure,
|
||||
UnsupportedLanguageConstruct)
|
||||
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
|
||||
|
||||
|
||||
def mangle_ty(ty):
|
||||
@@ -68,7 +67,10 @@ def _check_fn_args(node, fn, args):
|
||||
if fn.noinline:
|
||||
for idx, arg in enumerate(args):
|
||||
if not _is_constexpr(arg) and not _is_triton_scalar(arg):
|
||||
raise UnsupportedLanguageConstruct(fn.src, node, f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}')
|
||||
raise UnsupportedLanguageConstruct(
|
||||
fn.src, node,
|
||||
f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}'
|
||||
)
|
||||
|
||||
|
||||
def _get_fn_file_line(fn):
|
||||
@@ -89,6 +91,7 @@ _condition_types = {bool, int, type(None)} # Python types accepted for conditio
|
||||
|
||||
|
||||
class enter_sub_region:
|
||||
|
||||
def __init__(self, generator):
|
||||
self.generator = generator
|
||||
|
||||
@@ -109,6 +112,7 @@ class enter_sub_region:
|
||||
|
||||
# Check if the given syntax node has an "early" return
|
||||
class ContainsReturnChecker(ast.NodeVisitor):
|
||||
|
||||
def __init__(self, gscope):
|
||||
self.gscope = gscope
|
||||
|
||||
@@ -199,9 +203,10 @@ class ContainsReturnChecker(ast.NodeVisitor):
|
||||
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, function_name, target,
|
||||
module=None, is_kernel=False, function_types: Optional[Dict] = None,
|
||||
debug=False, noinline=False, file_name: Optional[str] = None, begin_line=0):
|
||||
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, function_name, target, module=None,
|
||||
is_kernel=False, function_types: Optional[Dict] = None, debug=False, noinline=False,
|
||||
file_name: Optional[str] = None, begin_line=0):
|
||||
self.context = context
|
||||
self.builder = ir.builder(context)
|
||||
self.file_name = file_name
|
||||
@@ -237,8 +242,10 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
))
|
||||
|
||||
def _define_name_lookup(self):
|
||||
|
||||
def local_lookup(name: str, absent):
|
||||
value = self.lscope.get(name, absent) # this needs to be re-fetched from `self` every time, because it gets switched occasionally
|
||||
# this needs to be re-fetched from `self` every time, because it gets switched occasionally
|
||||
value = self.lscope.get(name, absent)
|
||||
if value is not absent and name not in self.local_defs:
|
||||
self.global_uses[name] = value
|
||||
return value
|
||||
@@ -255,8 +262,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
return name_lookup
|
||||
|
||||
def set_value(self, name: str,
|
||||
value: Union[tensor, constexpr]) -> None:
|
||||
def set_value(self, name: str, value: Union[tensor, constexpr]) -> None:
|
||||
''' This function:
|
||||
called by visit_Assign() & visit_FunctionDef() to store left value (lvalue)
|
||||
1. record local defined name (FIXME: should consider control flow)
|
||||
@@ -338,7 +344,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.visit(init_node)
|
||||
# initialize function
|
||||
visibility = "public" if self.is_kernel else "private"
|
||||
self.fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline)
|
||||
self.fn = self.builder.get_or_insert_function(self.module, self.function_name,
|
||||
self.prototype.to_ir(self.builder), visibility, self.noinline)
|
||||
self.module.push_back(self.fn)
|
||||
entry = self.fn.add_entry_block()
|
||||
arg_values = []
|
||||
@@ -469,12 +476,23 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
rhs = self.visit(node.right)
|
||||
method_name = self._method_name_for_bin_op.get(type(node.op))
|
||||
if method_name is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
return self._apply_binary_method(method_name, lhs, rhs)
|
||||
|
||||
_method_name_for_bin_op: Dict[Type[ast.operator], str] = {
|
||||
ast.Add: '__add__', ast.Sub: '__sub__', ast.Mult: '__mul__', ast.Div: '__truediv__',
|
||||
ast.FloorDiv: '__floordiv__', ast.Mod: '__mod__', ast.Pow: '__pow__',
|
||||
ast.LShift: '__lshift__', ast.RShift: '__rshift__', ast.BitAnd: '__and__', ast.BitOr: '__or__', ast.BitXor: '__xor__',
|
||||
ast.Add: '__add__',
|
||||
ast.Sub: '__sub__',
|
||||
ast.Mult: '__mul__',
|
||||
ast.Div: '__truediv__',
|
||||
ast.FloorDiv: '__floordiv__',
|
||||
ast.Mod: '__mod__',
|
||||
ast.Pow: '__pow__',
|
||||
ast.LShift: '__lshift__',
|
||||
ast.RShift: '__rshift__',
|
||||
ast.BitAnd: '__and__',
|
||||
ast.BitOr: '__or__',
|
||||
ast.BitXor: '__xor__',
|
||||
}
|
||||
|
||||
def visit_then_else_blocks(self, node, liveins, then_block, else_block):
|
||||
@@ -508,7 +526,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if name in then_defs or name in else_defs:
|
||||
names.append(name)
|
||||
ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type)
|
||||
ir_ret_types.append(then_defs[name].handle.get_type() if name in then_defs else else_defs[name].handle.get_type())
|
||||
ir_ret_types.append(then_defs[name].handle.get_type() if name in
|
||||
then_defs else else_defs[name].handle.get_type())
|
||||
# variable defined in then but not in else
|
||||
if name in then_defs and name not in else_defs:
|
||||
else_defs[name] = liveins[name]
|
||||
@@ -602,8 +621,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
contains_return = ContainsReturnChecker(self.gscope).visit(node)
|
||||
if self.scf_stack and contains_return:
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node,
|
||||
"Cannot have `return` statements inside `while` or `for` statements in triton "
|
||||
None, node, "Cannot have `return` statements inside `while` or `for` statements in triton "
|
||||
"(note that this also applies to `return` statements that are inside functions "
|
||||
"transitively called from within `while`/`for` statements)")
|
||||
elif self.scf_stack or not contains_return:
|
||||
@@ -612,10 +630,13 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.visit_if_top_level(cond, node)
|
||||
else:
|
||||
cond = _unwrap_if_constexpr(cond)
|
||||
if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks
|
||||
# not isinstance - we insist the real thing, no subclasses and no ducks
|
||||
if type(cond) not in _condition_types:
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
|
||||
', '.join(_.__name__ for _ in _condition_types), type(cond).__name__))
|
||||
None, node,
|
||||
"`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
|
||||
', '.join(_.__name__ for _ in _condition_types),
|
||||
type(cond).__name__))
|
||||
if cond:
|
||||
self.visit_compound_statement(node.body)
|
||||
else:
|
||||
@@ -624,15 +645,52 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def visit_IfExp(self, node):
|
||||
cond = self.visit(node.test)
|
||||
if _is_triton_tensor(cond):
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node,
|
||||
"Triton does not support `if` expressions (ternary operators) with dynamic conditions, use `if` statements instead")
|
||||
cond = cond.to(language.int1, _builder=self.builder)
|
||||
# TODO: Deal w/ more complicated return types (e.g tuple)
|
||||
with enter_sub_region(self):
|
||||
ip, last_loc = self._get_insertion_point_and_loc()
|
||||
|
||||
then_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(then_block)
|
||||
then_val = language.core._to_tensor(self.visit(node.body), self.builder)
|
||||
then_block = self.builder.get_insertion_block()
|
||||
|
||||
else_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(else_block)
|
||||
# do not need to reset lscope since
|
||||
# ternary expressions cannot define new variables
|
||||
else_val = language.core._to_tensor(self.visit(node.orelse), self.builder)
|
||||
else_block = self.builder.get_insertion_block()
|
||||
|
||||
self._set_insertion_point_and_loc(ip, last_loc)
|
||||
|
||||
assert then_val.type == else_val.type, \
|
||||
f'ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}'
|
||||
ret_type = then_val.type
|
||||
|
||||
ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else []
|
||||
if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True)
|
||||
then_block.merge_block_before(if_op.get_then_block())
|
||||
if ret_type_ir:
|
||||
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
||||
self.builder.create_yield_op([then_val.handle])
|
||||
|
||||
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
||||
else_block.merge_block_before(if_op.get_else_block())
|
||||
if ret_type_ir:
|
||||
self.builder.set_insertion_point_to_end(if_op.get_else_block())
|
||||
self.builder.create_yield_op([else_val.handle])
|
||||
return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None
|
||||
else:
|
||||
cond = _unwrap_if_constexpr(cond)
|
||||
if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks
|
||||
|
||||
# not isinstance - we insist the real thing, no subclasses and no ducks
|
||||
if type(cond) not in _condition_types:
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
|
||||
', '.join(_.__name__ for _ in _condition_types), type(cond).__name__))
|
||||
None, node,
|
||||
"`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
|
||||
', '.join(_.__name__ for _ in _condition_types),
|
||||
type(cond).__name__))
|
||||
if cond:
|
||||
return self.visit(node.body)
|
||||
else:
|
||||
@@ -654,8 +712,10 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
return constexpr(lhs_value is not rhs_value)
|
||||
method_name = self._method_name_for_comp_op.get(type(node.ops[0]))
|
||||
if method_name is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
|
||||
return self._apply_binary_method(method_name, lhs, rhs)
|
||||
|
||||
_method_name_for_comp_op: Dict[Type[ast.cmpop], str] = {
|
||||
ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__'
|
||||
}
|
||||
@@ -664,11 +724,15 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
op = self.visit(node.operand)
|
||||
fn = self._method_name_for_unary_op.get(type(node.op))
|
||||
if fn is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST unary operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "AST unary operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
if _is_triton_tensor(op):
|
||||
return getattr(op, fn)(_builder=self.builder)
|
||||
return getattr(op, fn)()
|
||||
_method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'}
|
||||
|
||||
_method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {
|
||||
ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'
|
||||
}
|
||||
|
||||
def visit_While(self, node):
|
||||
with enter_sub_region(self) as sr:
|
||||
@@ -763,9 +827,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
iter_args = [self.visit(arg) for arg in node.iter.args]
|
||||
if IteratorClass == language.static_range:
|
||||
iterator = IteratorClass(*iter_args)
|
||||
static_range = range(iterator.start.value,
|
||||
iterator.end.value,
|
||||
iterator.step.value)
|
||||
static_range = range(iterator.start.value, iterator.end.value, iterator.step.value)
|
||||
for i in static_range:
|
||||
self.lscope[node.target.id] = constexpr(i)
|
||||
self.visit_compound_statement(node.body)
|
||||
@@ -902,8 +964,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def call_JitFunction(self, fn: JITFunction, args, kwargs):
|
||||
args = inspect.getcallargs(fn.fn, *args, **kwargs)
|
||||
args = [args[name] for name in fn.arg_names]
|
||||
args = [arg if _is_triton_tensor(arg)
|
||||
else constexpr(arg) for arg in args]
|
||||
args = [arg if _is_triton_tensor(arg) else constexpr(arg) for arg in args]
|
||||
# generate function def
|
||||
attributes = dict()
|
||||
constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
|
||||
@@ -921,8 +982,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
debug = self.debug if fn.debug is None else fn.debug
|
||||
file_name, begin_line = _get_fn_file_line(fn)
|
||||
generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module,
|
||||
function_name=fn_name, function_types=self.function_ret_types, debug=debug, noinline=fn.noinline,
|
||||
file_name=file_name, begin_line=begin_line, target=self.builder.target)
|
||||
function_name=fn_name, function_types=self.function_ret_types, debug=debug,
|
||||
noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
|
||||
target=self.builder.target)
|
||||
generator.visit(fn.parse())
|
||||
callee_ret_type = generator.last_ret_type
|
||||
self.function_ret_types[fn_name] = callee_ret_type
|
||||
@@ -950,7 +1012,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
kws = dict(self.visit(keyword) for keyword in node.keywords)
|
||||
args = [self.visit(arg) for arg in node.args]
|
||||
if fn is language.core.device_assert: # TODO: this should not be so hardcoded
|
||||
if fn is language.core.device_assert: # TODO: this should not be so hardcoded
|
||||
if not self.debug:
|
||||
return
|
||||
if isinstance(fn, JITFunction):
|
||||
@@ -971,16 +1033,21 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def visit_BoolOp(self, node: ast.BoolOp):
|
||||
if len(node.values) != 2:
|
||||
raise UnsupportedLanguageConstruct(None, node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node,
|
||||
"chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
|
||||
lhs = self.visit(node.values[0])
|
||||
rhs = self.visit(node.values[1])
|
||||
method_name = self._method_name_for_bool_op.get(type(node.op))
|
||||
if method_name is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
return self._apply_binary_method(method_name, lhs, rhs)
|
||||
|
||||
_method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
|
||||
def visit_NameConstant(self, node):
|
||||
return constexpr(node.value)
|
||||
|
||||
@@ -1013,7 +1080,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
evaluated = self.visit(value.value)
|
||||
if not _is_constexpr(evaluated):
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + str(type(evaluated)))
|
||||
None, node,
|
||||
"Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type "
|
||||
+ str(type(evaluated)))
|
||||
values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value)
|
||||
else:
|
||||
raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value)))
|
||||
@@ -1055,7 +1124,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
passed = _unwrap_if_constexpr(self.visit(node.args[0]))
|
||||
if not isinstance(passed, bool):
|
||||
raise NotImplementedError("Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values")
|
||||
raise NotImplementedError(
|
||||
"Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values"
|
||||
)
|
||||
if not passed:
|
||||
if arg_count == 1:
|
||||
message = ""
|
||||
@@ -1144,10 +1215,9 @@ def ast_to_ttir(fn, signature, specialization, constants, debug, target):
|
||||
file_name, begin_line = _get_fn_file_line(fn)
|
||||
|
||||
prototype = language.function_type([], arg_types)
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants,
|
||||
function_name=function_name, attributes=new_attrs,
|
||||
is_kernel=True, debug=debug, file_name=file_name, begin_line=begin_line,
|
||||
target=target)
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name,
|
||||
attributes=new_attrs, is_kernel=True, debug=debug, file_name=file_name,
|
||||
begin_line=begin_line, target=target)
|
||||
try:
|
||||
generator.visit(fn.parse())
|
||||
except CompilationError as e:
|
||||
|
||||
@@ -11,25 +11,21 @@ from typing import Any
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs,
|
||||
compile_ptx_to_cubin, get_env_vars, get_num_warps,
|
||||
get_shared_memory_size, ir, runtime,
|
||||
translate_llvmir_to_ptx,
|
||||
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs, compile_ptx_to_cubin, get_env_vars,
|
||||
get_num_warps, get_shared_memory_size, ir, runtime, translate_llvmir_to_ptx,
|
||||
translate_triton_gpu_to_llvmir)
|
||||
from ..common.backend import get_backend, path_to_ptxas
|
||||
from ..common.backend import get_backend, get_cuda_version_key, path_to_ptxas
|
||||
from ..common.build import is_hip
|
||||
# from ..runtime import driver, jit, JITFunction
|
||||
# TODO: runtime.errors
|
||||
from ..runtime.autotuner import OutOfResources
|
||||
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
|
||||
from ..runtime.driver import driver
|
||||
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device,
|
||||
get_device_capability, version_key)
|
||||
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device, get_device_capability)
|
||||
from ..tools.disasm import get_sass
|
||||
from .code_generator import ast_to_ttir
|
||||
from .make_launcher import make_stub
|
||||
from .utils import (InfoFromBackendForTensorMap, TensorMapManager,
|
||||
get_ids_of_tensormaps, parse_tma_info)
|
||||
from .utils import (InfoFromBackendForTensorMap, TensorMapManager, get_ids_of_tensormaps, parse_tma_info)
|
||||
|
||||
CUDA_DEFAULT_WARP_SIZE = 32
|
||||
|
||||
@@ -45,6 +41,7 @@ def _is_cuda(target):
|
||||
|
||||
|
||||
class LazyDict(dict):
|
||||
|
||||
def __getitem__(self, key):
|
||||
val = dict.__getitem__(self, key)
|
||||
if callable(val):
|
||||
@@ -103,8 +100,13 @@ def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, target):
|
||||
return mod
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target,
|
||||
cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_inst_type):
|
||||
=======
|
||||
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization,
|
||||
enable_persistent, optimize_epilogue):
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
is_cuda = _is_cuda(target)
|
||||
if is_cuda:
|
||||
capability = target.capability
|
||||
@@ -128,9 +130,13 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target,
|
||||
if optimize_epilogue:
|
||||
pm.add_tritongpu_optimize_epilogue_pass()
|
||||
pm.add_tritongpu_optimize_dot_operands_pass()
|
||||
<<<<<<< HEAD
|
||||
if num_stages == 0 and is_hip() and gpu_matrix_core_version() != 0:
|
||||
pm.add_tritongpu_stream_pipeline_pass()
|
||||
pm.add_canonicalizer_pass()
|
||||
=======
|
||||
pm.add_cse_pass()
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
ws_enabled = False
|
||||
# `num_warps` does not mean the total number of warps of a CTA when
|
||||
# warp specialization is enabled.
|
||||
@@ -174,6 +180,8 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target,
|
||||
if is_cuda and capability // 10 >= 9:
|
||||
pm.add_tritongpu_fence_insertion_pass()
|
||||
pm.add_tritongpu_ws_fixup_missing_attrs_pass()
|
||||
pm.add_tritongpu_optimize_thread_locality_pass()
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
@@ -197,6 +205,7 @@ def ttgir_to_llir(mod, extern_libs, target, tma_infos, waves_per_eu=0):
|
||||
|
||||
# PTX translation
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def ptx_get_version(cuda_version) -> int:
|
||||
'''
|
||||
@@ -261,7 +270,11 @@ def convert_type_repr(x):
|
||||
return x
|
||||
|
||||
|
||||
def make_hash(fn, target, env_vars, **kwargs):
|
||||
def make_hash(fn, target, env_vars, device_backend, **kwargs):
|
||||
if device_backend is None:
|
||||
version_key = get_cuda_version_key()
|
||||
else:
|
||||
version_key = device_backend.get_version_key()
|
||||
if isinstance(fn, JITFunction):
|
||||
configs = kwargs["configs"]
|
||||
signature = kwargs["signature"]
|
||||
@@ -275,16 +288,21 @@ def make_hash(fn, target, env_vars, **kwargs):
|
||||
enable_persistent = kwargs.get("enable_persistent", False)
|
||||
debug = kwargs.get("debug", False)
|
||||
# Get unique key for the compiled code
|
||||
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
|
||||
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1),
|
||||
sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
|
||||
configs_key = [get_conf_key(conf) for conf in configs]
|
||||
env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())]
|
||||
<<<<<<< HEAD
|
||||
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
|
||||
=======
|
||||
key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
return hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||
assert isinstance(fn, str)
|
||||
ignore_version = kwargs.get('ignore_version', False)
|
||||
if (ignore_version):
|
||||
return hashlib.md5((Path(fn).read_text()).encode("utf-8")).hexdigest()
|
||||
return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest()
|
||||
return hashlib.md5((Path(fn).read_text() + version_key).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
|
||||
@@ -321,12 +339,14 @@ else:
|
||||
|
||||
|
||||
def _get_jsonable_constants(constants):
|
||||
|
||||
def _is_jsonable(x):
|
||||
try:
|
||||
json.dumps(x)
|
||||
return True
|
||||
except (TypeError, OverflowError):
|
||||
return False
|
||||
|
||||
serialized_constants = {}
|
||||
for constant in constants:
|
||||
if _is_jsonable(constants[constant]):
|
||||
@@ -341,7 +361,9 @@ def parse_mlir_module(path, context):
|
||||
return module
|
||||
|
||||
|
||||
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], defaults=[set(), set(), set(), set()])
|
||||
instance_descriptor = namedtuple("instance_descriptor",
|
||||
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
|
||||
defaults=[set(), set(), set(), set()])
|
||||
|
||||
|
||||
def is_hip():
|
||||
@@ -385,10 +407,16 @@ def get_arch_default_num_stages(device_type, capability=None):
|
||||
|
||||
|
||||
def add_cuda_stages(target, extern_libs, stages):
|
||||
<<<<<<< HEAD
|
||||
stages["ptx"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: llir_to_ptx(src, target))
|
||||
stages["cubin"] = (lambda path: Path(path).read_bytes(),
|
||||
lambda src: ptx_to_cubin(src, target))
|
||||
=======
|
||||
|
||||
stages["ptx"] = (lambda path: Path(path).read_text(), lambda src: llir_to_ptx(src, target))
|
||||
stages["cubin"] = (lambda path: Path(path).read_bytes(), lambda src: ptx_to_cubin(src, target))
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
|
||||
|
||||
def compile(fn, **kwargs):
|
||||
@@ -434,7 +462,8 @@ def compile(fn, **kwargs):
|
||||
# build architecture descriptor
|
||||
if device_type == "cuda":
|
||||
_device_backend = get_backend(device_type)
|
||||
target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps, enable_fp_fusion=enable_fp_fusion)
|
||||
target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps,
|
||||
enable_fp_fusion=enable_fp_fusion)
|
||||
else:
|
||||
_device_backend = get_backend(device_type)
|
||||
assert _device_backend
|
||||
@@ -443,11 +472,12 @@ def compile(fn, **kwargs):
|
||||
# build compilation stages
|
||||
stages = dict()
|
||||
stages["ast"] = (lambda path: fn, None)
|
||||
stages["ttir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
|
||||
stages["ttir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttir(
|
||||
ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
|
||||
if is_cuda:
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttgir(
|
||||
ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info,
|
||||
enable_warp_specialization, enable_persistent, optimize_epilogue))
|
||||
stages["llir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, target, tma_infos))
|
||||
add_cuda_stages(target, extern_libs, stages)
|
||||
@@ -507,18 +537,21 @@ def compile(fn, **kwargs):
|
||||
if ir_name == 'ttgir':
|
||||
num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
|
||||
assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
|
||||
assert "num_warps" not in kwargs or int(num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile"
|
||||
assert "num_warps" not in kwargs or int(
|
||||
num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile"
|
||||
num_warps = int(num_warps_matches[0])
|
||||
param_tys = [convert_type_repr(ty) for ty in types]
|
||||
signature = {k: v for k, v in enumerate(param_tys)}
|
||||
first_stage = list(stages.keys()).index(ir_name)
|
||||
|
||||
# create cache manager
|
||||
fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), **kwargs))
|
||||
fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs))
|
||||
# managers used to dump and override IR for debugging
|
||||
enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
|
||||
fn_override_manager = get_override_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True))
|
||||
fn_dump_manager = get_dump_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True))
|
||||
fn_override_manager = get_override_manager(
|
||||
make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
|
||||
fn_dump_manager = get_dump_manager(
|
||||
make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
|
||||
|
||||
# determine name and extension type of provided function
|
||||
if isinstance(fn, JITFunction):
|
||||
@@ -531,9 +564,7 @@ def compile(fn, **kwargs):
|
||||
metadata_filename = f"{name}.json"
|
||||
|
||||
# The group is addressed by the metadata
|
||||
metadata_group = fn_cache_manager.get_group(
|
||||
metadata_filename
|
||||
) or {}
|
||||
metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
|
||||
|
||||
metadata_path = metadata_group.get(metadata_filename)
|
||||
|
||||
@@ -541,9 +572,9 @@ def compile(fn, **kwargs):
|
||||
with open(metadata_path) as f:
|
||||
metadata = json.load(f)
|
||||
if 'tensormaps_info' in metadata:
|
||||
metadata['tensormaps_info'] = [
|
||||
InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']]
|
||||
metadata['tensormaps_info'] = [InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']]
|
||||
else:
|
||||
<<<<<<< HEAD
|
||||
metadata = {"num_warps": num_warps,
|
||||
"warp_size": warp_size,
|
||||
"num_ctas": num_ctas,
|
||||
@@ -555,6 +586,18 @@ def compile(fn, **kwargs):
|
||||
"constants": _get_jsonable_constants(constants),
|
||||
"debug": debug,
|
||||
"target": target, }
|
||||
=======
|
||||
metadata = {
|
||||
"num_warps": num_warps,
|
||||
"num_ctas": num_ctas,
|
||||
"num_stages": num_stages,
|
||||
"enable_warp_specialization": enable_warp_specialization,
|
||||
"enable_persistent": enable_persistent,
|
||||
"constants": _get_jsonable_constants(constants),
|
||||
"debug": debug,
|
||||
"target": target,
|
||||
}
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
metadata.update(get_env_vars())
|
||||
if ext == "ptx":
|
||||
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
|
||||
@@ -626,10 +669,7 @@ def compile(fn, **kwargs):
|
||||
|
||||
ids_of_folded_args = tuple([int(k) for k in configs[0].ids_of_folded_args]) if isinstance(fn, JITFunction) else ()
|
||||
if "clusterDims" not in metadata:
|
||||
metadata["clusterDims"] = [
|
||||
cluster_info.clusterDimX,
|
||||
cluster_info.clusterDimY,
|
||||
cluster_info.clusterDimZ]
|
||||
metadata["clusterDims"] = [cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ]
|
||||
|
||||
if len(tma_infos) > 0:
|
||||
metadata["tensormaps_info"] = parse_tma_info(tma_infos, ids_of_folded_args)
|
||||
@@ -643,7 +683,10 @@ def compile(fn, **kwargs):
|
||||
fn.tensormaps_info = metadata["tensormaps_info"]
|
||||
|
||||
ids_of_const_exprs = tuple(fn.constexprs) if isinstance(fn, JITFunction) else ()
|
||||
ids = {"ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs": ids_of_const_exprs}
|
||||
ids = {
|
||||
"ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs":
|
||||
ids_of_const_exprs
|
||||
}
|
||||
# cache manager
|
||||
if is_cuda:
|
||||
so_path = make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization)
|
||||
@@ -651,7 +694,8 @@ def compile(fn, **kwargs):
|
||||
so_path = _device_backend.make_launcher_stub(name, signature, constants, ids)
|
||||
# write-back metadata, if it didn't come from the cache
|
||||
if metadata_path is None:
|
||||
metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, binary=False)
|
||||
metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
|
||||
binary=False)
|
||||
fn_cache_manager.put_group(metadata_filename, metadata_group)
|
||||
|
||||
# return handle to compiled kernel
|
||||
@@ -701,10 +745,7 @@ class CompiledKernel:
|
||||
|
||||
if self.device_type in ["cuda"]:
|
||||
device = get_current_device()
|
||||
bin_path = {
|
||||
driver.HIP: "hsaco_path",
|
||||
driver.CUDA: "cubin"
|
||||
}[driver.backend]
|
||||
bin_path = {driver.HIP: "hsaco_path", driver.CUDA: "cubin"}[driver.backend]
|
||||
max_shared = driver.utils.get_device_properties(device)["max_shared_mem"]
|
||||
fn_load_binary = driver.utils.load_binary
|
||||
else:
|
||||
@@ -752,4 +793,5 @@ class CompiledKernel:
|
||||
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.clusterDims[0],
|
||||
self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function,
|
||||
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand)
|
||||
|
||||
return runner
|
||||
|
||||
@@ -3,9 +3,9 @@ import os
|
||||
import tempfile
|
||||
|
||||
from ..common import _build
|
||||
from ..common.backend import get_cuda_version_key
|
||||
from ..common.build import is_hip
|
||||
from ..runtime.cache import get_cache_manager
|
||||
from ..runtime.jit import version_key
|
||||
from .utils import generate_cu_signature
|
||||
|
||||
# ----- stub --------
|
||||
@@ -23,7 +23,7 @@ def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
|
||||
|
||||
def make_stub(name, signature, constants, ids, **kwargs):
|
||||
# name of files that are cached
|
||||
so_cache_key = make_so_cache_key(version_key(), signature, constants, ids, **kwargs)
|
||||
so_cache_key = make_so_cache_key(get_cuda_version_key(), signature, constants, ids, **kwargs)
|
||||
so_cache_manager = get_cache_manager(so_cache_key)
|
||||
so_name = f"{name}.so"
|
||||
# retrieve stub from cache if it exists
|
||||
@@ -40,6 +40,7 @@ def make_stub(name, signature, constants, ids, **kwargs):
|
||||
else:
|
||||
return cache_path
|
||||
|
||||
|
||||
# ----- source code generation --------
|
||||
|
||||
|
||||
@@ -100,7 +101,10 @@ def generate_launcher(constants, signature, ids):
|
||||
|
||||
# generate glue code
|
||||
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
|
||||
params = [i for i in signature.keys() if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)]
|
||||
params = [
|
||||
i for i in signature.keys()
|
||||
if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)
|
||||
]
|
||||
src = f"""
|
||||
#include \"cuda.h\"
|
||||
#include <stdbool.h>
|
||||
|
||||
@@ -158,19 +158,21 @@ class InfoFromBackendForTensorMap:
|
||||
|
||||
# dtype:cuda.CUtensorMapDataType | int
|
||||
def bytes_from_type(self, dtype):
|
||||
return {driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT8"]: 1,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ"]: 4}[dtype]
|
||||
return {
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT8"]: 1,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ"]: 4
|
||||
}[dtype]
|
||||
|
||||
def getTensorMapDataType(self):
|
||||
return self.tensorDataType
|
||||
@@ -259,22 +261,29 @@ class InfoFromBackendForTensorMap:
|
||||
self.getInterleave(),
|
||||
self.getSwizzle(),
|
||||
self.getL2Promotion(),
|
||||
self.getOobFill()
|
||||
self.getOobFill(),
|
||||
)
|
||||
|
||||
# make hashable to use as partial key in cache
|
||||
def __hash__(self):
|
||||
return hash((self.ids_of_folded_args, self.globalAddressArgIdx, tuple(self.globalDimsArgIdx), tuple(self.globalStridesArgIdx), self.tensorDataType,
|
||||
self.tensorRank, tuple(self.boxDims), tuple(self.elementStrides), self.interleave, self.swizzle, self.l2Promotion, self.oobFill))
|
||||
return hash((self.ids_of_folded_args, self.globalAddressArgIdx, tuple(self.globalDimsArgIdx),
|
||||
tuple(self.globalStridesArgIdx), self.tensorDataType, self.tensorRank, tuple(self.boxDims),
|
||||
tuple(self.elementStrides), self.interleave, self.swizzle, self.l2Promotion, self.oobFill))
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
return (self.ids_of_folded_args, self.globalAddressArgIdx, self.globalDimsArgIdx, self.globalStridesArgIdx, self.tensorDataType, self.tensorRank, self.boxDims, self.elementStrides, self.interleave, self.swizzle, self.l2Promotion, self.oobFill) == (
|
||||
other.ids_of_folded_args, other.globalAddressArgIdx, other.globalDimsArgIdx, other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims, other.elementStrides, other.interleave, other.swizzle, other.l2Promotion, other.oobFill)
|
||||
return (self.ids_of_folded_args, self.globalAddressArgIdx, self.globalDimsArgIdx, self.globalStridesArgIdx,
|
||||
self.tensorDataType, self.tensorRank, self.boxDims, self.elementStrides, self.interleave, self.swizzle,
|
||||
self.l2Promotion,
|
||||
self.oobFill) == (other.ids_of_folded_args, other.globalAddressArgIdx, other.globalDimsArgIdx,
|
||||
other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims,
|
||||
other.elementStrides, other.interleave, other.swizzle, other.l2Promotion,
|
||||
other.oobFill)
|
||||
|
||||
|
||||
class TensorMapManager:
|
||||
|
||||
def __init__(self):
|
||||
self.tensormaps_device = {}
|
||||
|
||||
@@ -286,8 +295,7 @@ class TensorMapManager:
|
||||
t_tensormap = e.tensormap(args)
|
||||
TENSORMAP_SIZE_IN_BYTES = 128
|
||||
t_tensormap_device = driver.utils.cuMemAlloc(TENSORMAP_SIZE_IN_BYTES)
|
||||
driver.utils.cuMemcpyHtoD(
|
||||
t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES)
|
||||
driver.utils.cuMemcpyHtoD(t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES)
|
||||
self.tensormaps_device[key] = t_tensormap_device
|
||||
return int(self.tensormaps_device[key])
|
||||
|
||||
|
||||
@@ -111,7 +111,6 @@ from .random import (
|
||||
uint32_to_uniform_float,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TRITON_MAX_TENSOR_NUMEL",
|
||||
"abs",
|
||||
|
||||
@@ -22,10 +22,8 @@ def builtin(fn: T) -> T:
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if "_builder" not in kwargs or kwargs["_builder"] is None:
|
||||
raise ValueError(
|
||||
"Did you forget to add @triton.jit ? "
|
||||
"(`_builder` argument must be provided outside of JIT functions.)"
|
||||
)
|
||||
raise ValueError("Did you forget to add @triton.jit ? "
|
||||
"(`_builder` argument must be provided outside of JIT functions.)")
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
setattr(wrapper, TRITON_BUILTIN, True)
|
||||
@@ -54,7 +52,7 @@ def _to_tensor(x, builder):
|
||||
else:
|
||||
raise RuntimeError(f'Nonrepresentable integer {x}.')
|
||||
elif isinstance(x, float):
|
||||
min_float32 = 2 ** -126
|
||||
min_float32 = 2**-126
|
||||
max_float32 = (2 - 2**-23) * 2**127
|
||||
abs_x = __builtins__['abs'](x)
|
||||
if abs_x == float("inf") or\
|
||||
@@ -243,7 +241,7 @@ class dtype:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.name,))
|
||||
return hash((self.name, ))
|
||||
|
||||
@property
|
||||
def scalar(self):
|
||||
@@ -297,6 +295,7 @@ class dtype:
|
||||
|
||||
|
||||
class pointer_type(dtype):
|
||||
|
||||
def __init__(self, element_ty: dtype, address_space: int = 1):
|
||||
if not isinstance(element_ty, dtype):
|
||||
raise TypeError('element_ty is a {type(element_ty).__name__}.')
|
||||
@@ -331,6 +330,7 @@ class pointer_type(dtype):
|
||||
|
||||
|
||||
class block_type(dtype):
|
||||
|
||||
def __init__(self, element_ty: dtype, shape: List):
|
||||
self.element_ty = element_ty
|
||||
|
||||
@@ -381,6 +381,7 @@ class block_type(dtype):
|
||||
|
||||
|
||||
class function_type(dtype):
|
||||
|
||||
def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None:
|
||||
self.ret_types = ret_types
|
||||
self.param_types = param_types
|
||||
@@ -531,7 +532,7 @@ class constexpr:
|
||||
return constexpr(~self.value)
|
||||
|
||||
def __pow__(self, other):
|
||||
return constexpr(self.value ** other.value)
|
||||
return constexpr(self.value**other.value)
|
||||
|
||||
def __rshift__(self, other):
|
||||
return constexpr(self.value >> other.value)
|
||||
@@ -547,6 +548,7 @@ class constexpr:
|
||||
|
||||
|
||||
class tensor:
|
||||
|
||||
def __init__(self, handle, type: dtype):
|
||||
# IR handle
|
||||
self.handle = handle
|
||||
@@ -740,11 +742,21 @@ class tensor:
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.equal(self, other, _builder)
|
||||
|
||||
@builtin
|
||||
def __req__(self, other, _builder=None):
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.equal(other, self, _builder)
|
||||
|
||||
@builtin
|
||||
def __ne__(self, other, _builder=None):
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.not_equal(self, other, _builder)
|
||||
|
||||
@builtin
|
||||
def __rne__(self, other, _builder=None):
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.not_equal(other, self, _builder)
|
||||
|
||||
@builtin
|
||||
def logical_and(self, other, _builder=None):
|
||||
other = _to_tensor(other, _builder)
|
||||
@@ -1023,6 +1035,7 @@ def expand_dims(input, axis, _builder=None):
|
||||
ret = semantic.expand_dims(ret, a, _builder)
|
||||
return ret
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Linear Algebra
|
||||
# -----------------------
|
||||
@@ -1171,6 +1184,7 @@ def advance(base: tensor, offsets, _builder=None):
|
||||
"""
|
||||
return semantic.advance(base, offsets, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Atomic Memory Operations
|
||||
# -----------------------
|
||||
@@ -1196,6 +1210,9 @@ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
|
||||
:param sem: Memory semantics to use ("ACQUIRE_RELEASE" (default),
|
||||
"ACQUIRE", "RELEASE", or "RELAXED")
|
||||
:type sem: str
|
||||
:param scope: Scope of threads that observe synchronizing effect of the
|
||||
atomic operation ("GPU" (default), "CTA", or "SYSTEM")
|
||||
:type scope: str
|
||||
"""
|
||||
func.__doc__ = docstr
|
||||
return func
|
||||
@@ -1205,73 +1222,82 @@ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("compare-and-swap", has_cmp=True)
|
||||
def atomic_cas(pointer, cmp, val, sem=None, _builder=None):
|
||||
def atomic_cas(pointer, cmp, val, sem=None, scope=None, _builder=None):
|
||||
cmp = _to_tensor(cmp, _builder)
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_cas(pointer, cmp, val, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_cas(pointer, cmp, val, sem, scope, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("exchange")
|
||||
def atomic_xchg(pointer, val, mask=None, sem=None, _builder=None):
|
||||
def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_xchg(pointer, val, mask, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_xchg(pointer, val, mask, sem, scope, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("add")
|
||||
def atomic_add(pointer, val, mask=None, sem=None, _builder=None):
|
||||
def atomic_add(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_add(pointer, val, mask, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_add(pointer, val, mask, sem, scope, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("max")
|
||||
def atomic_max(pointer, val, mask=None, sem=None, _builder=None):
|
||||
def atomic_max(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_max(pointer, val, mask, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_max(pointer, val, mask, sem, scope, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("min")
|
||||
def atomic_min(pointer, val, mask=None, sem=None, _builder=None):
|
||||
def atomic_min(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_min(pointer, val, mask, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_min(pointer, val, mask, sem, scope, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("logical and")
|
||||
def atomic_and(pointer, val, mask=None, sem=None, _builder=None):
|
||||
def atomic_and(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_and(pointer, val, mask, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_and(pointer, val, mask, sem, scope, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("logical or")
|
||||
def atomic_or(pointer, val, mask=None, sem=None, _builder=None):
|
||||
def atomic_or(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_or(pointer, val, mask, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_or(pointer, val, mask, sem, scope, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("logical xor")
|
||||
def atomic_xor(pointer, val, mask=None, sem=None, _builder=None):
|
||||
def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_xor(pointer, val, mask, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_xor(pointer, val, mask, sem, scope, _builder)
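Note: every atomic builtin above now passes an optional `scope` through `_constexpr_to_value` before calling into `semantic`. A minimal usage sketch (the kernel name and the lowercase "sys" spelling are illustrative assumptions, not taken from this diff; the docstring above describes the scopes as "GPU", "CTA", and "SYSTEM"):

import triton
import triton.language as tl

@triton.jit
def signal_done(flag_ptr):
    # atomically set a completion flag; widening the scope makes the update
    # observable beyond the launching GPU (see the scope docstring above)
    tl.atomic_xchg(flag_ptr, 1, scope="sys")

# e.g. flag = torch.zeros(1, dtype=torch.int32, device="cuda"); signal_done[(1, )](flag)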
# -----------------------
|
||||
# Conditioning
|
||||
# -----------------------
|
||||
|
||||
|
||||
@builtin
|
||||
def where(condition, x, y, _builder=None):
|
||||
"""
|
||||
@@ -1299,6 +1325,7 @@ def where(condition, x, y, _builder=None):
|
||||
# Math
|
||||
# -----------------------
|
||||
|
||||
|
||||
@builtin
|
||||
def umulhi(x, y, _builder=None):
|
||||
"""
|
||||
@@ -1392,6 +1419,7 @@ def abs(x, _builder=None):
|
||||
# Reductions
|
||||
# -----------------------
|
||||
|
||||
|
||||
def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]:
|
||||
|
||||
def _decorator(func: T) -> T:
|
||||
@@ -1430,8 +1458,7 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
|
||||
"""
|
||||
if isinstance(input, tensor):
|
||||
return reduce((input,), axis, combine_fn,
|
||||
_builder=_builder, _generator=_generator)[0]
|
||||
return reduce((input, ), axis, combine_fn, _builder=_builder, _generator=_generator)[0]
|
||||
|
||||
def make_combine_region(reduce_op):
|
||||
in_scalar_tys = [t.type.scalar for t in input]
|
||||
@@ -1441,14 +1468,14 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
with _insertion_guard(_builder):
|
||||
param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
|
||||
block = _builder.create_block_with_parent(region, param_types)
|
||||
args = [tensor(block.arg(i), ty)
|
||||
for i, ty in enumerate(prototype.param_types)]
|
||||
args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
|
||||
results = _generator.call_JitFunction(combine_fn, args, kwargs={})
|
||||
if isinstance(results, tensor):
|
||||
handles = [results.handle]
|
||||
else:
|
||||
handles = [r.handle for r in results]
|
||||
_builder.create_reduce_ret(*handles)
|
||||
|
||||
if axis is not None:
|
||||
axis = _constexpr_to_value(axis)
|
||||
return semantic.reduction(input, axis, make_combine_region, _builder)
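As context for the `call_JitFunction(combine_fn, ...)` path above: `combine_fn` is itself a `@triton.jit` function that gets traced into the reduce op's combine region. A small illustrative sketch (kernel names are hypothetical, not from this diff):

import triton
import triton.language as tl

@triton.jit
def _max_combine(a, b):
    return tl.maximum(a, b)

@triton.jit
def row_max(X, Out, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    # the combine function is applied pairwise inside the reduction region,
    # exactly as make_combine_region builds it above
    tl.store(Out, tl.reduce(x, 0, _max_combine))

# e.g. row_max[(1, )](x, out, BLOCK=1024)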
@@ -1483,8 +1510,7 @@ def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None
|
||||
index = expand_dims(index, axes_to_expand, _builder=_builder)
|
||||
index = broadcast_to(index, input.shape, _builder=_builder)
|
||||
|
||||
rvalue, rindices = reduce((input, index), axis, combine_fn,
|
||||
_builder=_builder, _generator=_generator)
|
||||
rvalue, rindices = reduce((input, index), axis, combine_fn, _builder=_builder, _generator=_generator)
|
||||
return rvalue, rindices
|
||||
|
||||
|
||||
@@ -1492,6 +1518,7 @@ def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None
|
||||
# Scans
|
||||
# -----------------------
|
||||
|
||||
|
||||
def _add_scan_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]:
|
||||
|
||||
def _decorator(func: T) -> T:
|
||||
@@ -1516,8 +1543,7 @@ def associative_scan(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
|
||||
"""
|
||||
if isinstance(input, tensor):
|
||||
return associative_scan((input,), axis, combine_fn,
|
||||
_builder=_builder, _generator=_generator)[0]
|
||||
return associative_scan((input, ), axis, combine_fn, _builder=_builder, _generator=_generator)[0]
|
||||
|
||||
def make_combine_region(scan_op):
|
||||
in_scalar_tys = [t.type.scalar for t in input]
|
||||
@@ -1527,17 +1553,18 @@ def associative_scan(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
with _insertion_guard(_builder):
|
||||
param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
|
||||
block = _builder.create_block_with_parent(region, param_types)
|
||||
args = [tensor(block.arg(i), ty)
|
||||
for i, ty in enumerate(prototype.param_types)]
|
||||
args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
|
||||
results = _generator.call_JitFunction(combine_fn, args, kwargs={})
|
||||
if isinstance(results, tensor):
|
||||
handles = [results.handle]
|
||||
else:
|
||||
handles = [r.handle for r in results]
|
||||
_builder.create_scan_ret(*handles)
|
||||
|
||||
axis = _constexpr_to_value(axis)
|
||||
return semantic.associative_scan(input, axis, make_combine_region, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Compiler Hint Ops
|
||||
# -----------------------
|
||||
@@ -1600,6 +1627,8 @@ def max_constancy(input, values, _builder=None):
|
||||
raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
|
||||
values = [x.value for x in values]
|
||||
return semantic.max_constancy(input, values)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Debugging functions
|
||||
# -----------------------
|
||||
@@ -1739,12 +1768,12 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur
|
||||
broadcast_arg = dispatch_args[0]
|
||||
# Get the broadcast shape over all the arguments
|
||||
for i, item in enumerate(dispatch_args):
|
||||
_, broadcast_arg = semantic.binary_op_type_checking_impl(
|
||||
item, broadcast_arg, _builder, arithmetic_check=False)
|
||||
_, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder,
|
||||
arithmetic_check=False)
|
||||
# Change the shape of each argument based on the broadcast shape
|
||||
for i in range(len(dispatch_args)):
|
||||
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
|
||||
dispatch_args[i], broadcast_arg, _builder, arithmetic_check=False)
|
||||
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder,
|
||||
arithmetic_check=False)
|
||||
ret_shape = broadcast_arg.shape
|
||||
res_ty = block_type(dtype, ret_shape)
|
||||
call = _builder.create_inline_asm(asm, constraints, [t.handle for t in args], res_ty.to_ir(_builder), is_pure, pack)
|
||||
@@ -1757,7 +1786,6 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur
|
||||
|
||||
|
||||
class static_range:
|
||||
|
||||
"""
|
||||
Iterator that counts upward forever.
|
||||
|
||||
@@ -1801,7 +1829,9 @@ class static_range:
|
||||
# Extern functions
|
||||
# -----------------------
|
||||
|
||||
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, is_pure: bool, _builder=None):
|
||||
|
||||
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple,
|
||||
is_pure: bool, _builder=None):
|
||||
'''
|
||||
Dispatch a function to a library
|
||||
:param func: the function to dispatch
|
||||
@@ -1843,7 +1873,8 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic
|
||||
return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type)
|
||||
|
||||
|
||||
def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, _builder=None):
|
||||
def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool,
|
||||
_builder=None):
|
||||
'''
|
||||
Dispatch an elementwise function to a library
|
||||
:param lib_name: the name of the library
|
||||
@@ -1872,12 +1903,12 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
|
||||
broadcast_arg = dispatch_args[0]
|
||||
# Get the broadcast shape over all the arguments
|
||||
for i, item in enumerate(dispatch_args):
|
||||
_, broadcast_arg = semantic.binary_op_type_checking_impl(
|
||||
item, broadcast_arg, _builder, arithmetic_check=arithmetic_check)
|
||||
_, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder,
|
||||
arithmetic_check=arithmetic_check)
|
||||
# Change the shape of each argument based on the broadcast shape
|
||||
for i in range(len(dispatch_args)):
|
||||
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
|
||||
dispatch_args[i], broadcast_arg, _builder, arithmetic_check=arithmetic_check)
|
||||
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder,
|
||||
arithmetic_check=arithmetic_check)
|
||||
if not all_scalar:
|
||||
ret_shape = broadcast_arg.shape
|
||||
func = getattr(_builder, "create_extern_elementwise")
|
||||
|
||||
@@ -3,16 +3,14 @@ from .. import core
|
||||
|
||||
@core.extern
|
||||
def globaltimer(_builder=None):
|
||||
return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [],
|
||||
dtype=core.int64, is_pure=False,
|
||||
pack=1, _builder=_builder)
|
||||
return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], dtype=core.int64, is_pure=False, pack=1,
|
||||
_builder=_builder)
|
||||
|
||||
|
||||
@core.extern
|
||||
def smid(_builder=None):
|
||||
return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [],
|
||||
dtype=core.int32, is_pure=True,
|
||||
pack=1, _builder=_builder)
|
||||
return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1,
|
||||
_builder=_builder)
|
||||
|
||||
|
||||
@core.builtin
|
||||
|
||||
File diff suppressed because it is too large
@@ -91,6 +91,7 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
# two_to_the_minus_32: tl.constexpr = 2.328306e-10
|
||||
# return x * two_to_the_minus_32
|
||||
|
||||
|
||||
@jit
|
||||
def uint32_to_uniform_float(x):
|
||||
"""
|
||||
@@ -134,6 +135,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
u4 = uint32_to_uniform_float(i4)
|
||||
return u1, u2, u3, u4
|
||||
|
||||
|
||||
# -------------------
|
||||
# randn
|
||||
# -------------------
|
||||
|
||||
File diff suppressed because it is too large
@@ -123,6 +123,7 @@ def maximum(x, y):
|
||||
"""
|
||||
return math.max(x, y)
|
||||
|
||||
|
||||
# max and argmax
|
||||
|
||||
|
||||
@@ -149,8 +150,7 @@ def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_reduction_docstr("maximum",
|
||||
return_indices_arg="return_indices",
|
||||
@core._add_reduction_docstr("maximum", return_indices_arg="return_indices",
|
||||
tie_break_arg="return_indices_tie_break_left")
|
||||
def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
|
||||
input = core._promote_reduction_input(input)
|
||||
@@ -175,6 +175,7 @@ def argmax(input, axis, tie_break_left=True):
|
||||
(_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
|
||||
return ret
|
||||
|
||||
|
||||
# min and argmin
|
||||
|
||||
|
||||
@@ -201,8 +202,7 @@ def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_reduction_docstr("minimum",
|
||||
return_indices_arg="return_indices",
|
||||
@core._add_reduction_docstr("minimum", return_indices_arg="return_indices",
|
||||
tie_break_arg="return_indices_tie_break_left")
|
||||
def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
|
||||
input = core._promote_reduction_input(input)
|
||||
@@ -222,8 +222,7 @@ def min(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_reduction_docstr("minimum index",
|
||||
tie_break_arg="tie_break_left")
|
||||
@core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left")
|
||||
def argmin(input, axis, tie_break_left=True):
|
||||
_, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
|
||||
return ret
|
||||
@@ -233,6 +232,7 @@ def argmin(input, axis, tie_break_left=True):
|
||||
def _sum_combine(a, b):
|
||||
return a + b
|
||||
|
||||
|
||||
# sum
|
||||
|
||||
|
||||
@@ -247,6 +247,7 @@ def sum(input, axis=None):
|
||||
def _xor_combine(a, b):
|
||||
return a ^ b
|
||||
|
||||
|
||||
# xor sum
|
||||
|
||||
|
||||
@@ -258,8 +259,8 @@ def xor_sum(input, axis=None, _builder=None, _generator=None):
|
||||
raise ValueError("xor_sum only supported for integers")
|
||||
|
||||
input = core._promote_reduction_input(input, _builder=_builder)
|
||||
return core.reduce(input, axis, _xor_combine,
|
||||
_builder=_builder, _generator=_generator)
|
||||
return core.reduce(input, axis, _xor_combine, _builder=_builder, _generator=_generator)
|
||||
|
||||
|
||||
# cumsum
|
||||
|
||||
@@ -271,6 +272,7 @@ def cumsum(input, axis=0):
|
||||
input = core._promote_reduction_input(input)
|
||||
return core.associative_scan(input, axis, _sum_combine)
|
||||
|
||||
|
||||
# cumprod
|
||||
|
||||
|
||||
|
||||
@@ -17,15 +17,14 @@ from ... import language as tl
|
||||
'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
|
||||
})
|
||||
@jit
|
||||
def _sdd_kernel(
|
||||
A, B, C,
|
||||
stride_za, stride_ha, stride_ma, stride_ak,
|
||||
stride_zb, stride_hb, stride_bk, stride_nb,
|
||||
stride_zc, stride_hc, stride_mc, stride_nc,
|
||||
K, grid_offset, lut,
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
|
||||
BLOCK: tl.constexpr, EVEN_K: tl.constexpr
|
||||
):
|
||||
def _sdd_kernel(A, B, C, #
|
||||
stride_za, stride_ha, stride_ma, stride_ak, #
|
||||
stride_zb, stride_hb, stride_bk, stride_nb, #
|
||||
stride_zc, stride_hc, stride_mc, stride_nc, #
|
||||
K, grid_offset, lut, #
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, #
|
||||
BLOCK: tl.constexpr, EVEN_K: tl.constexpr #
|
||||
):
|
||||
# ------------ #
|
||||
# - Prologue - #
|
||||
# ------------ #
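The bare trailing `#` added to each argument line here (and in the kernels below) is a line-continuation comment, presumably kept so the formatter does not re-join the argument list onto fewer lines. The pattern in isolation looks like this (hypothetical kernel, not from this diff):

import triton
import triton.language as tl

@triton.jit
def _copy_kernel(X, Y,  #
                 stride_x, stride_y,  #
                 BLOCK: tl.constexpr  #
                 ):
    offs = tl.arange(0, BLOCK)
    # the trailing '#' on each line pins the one-group-of-arguments-per-line layout
    tl.store(Y + offs * stride_y, tl.load(X + offs * stride_x))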
@@ -104,13 +103,13 @@ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=
|
||||
c = out
|
||||
grid = [c.shape[1], 1, c.shape[0]]
|
||||
_sdd_kernel[grid](
|
||||
a, b, c,
|
||||
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
|
||||
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
|
||||
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
|
||||
Ka, 0, lut,
|
||||
TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4,
|
||||
num_warps=4,
|
||||
a, b, c, #
|
||||
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), #
|
||||
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), #
|
||||
c.stride(0), c.stride(1), c.stride(2), c.stride(3), #
|
||||
Ka, 0, lut, #
|
||||
TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4, #
|
||||
num_warps=4 #
|
||||
)
|
||||
return c
|
||||
|
||||
@@ -120,6 +119,7 @@ def sdd_lut(layout, block, device):
|
||||
lut = lut.contiguous()
|
||||
return lut, None
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Dense = Sparse x Dense (DSD)
|
||||
# This operation uses a look-up table that contains pre-computed pointer increments
|
||||
@@ -128,15 +128,14 @@ def sdd_lut(layout, block, device):
|
||||
|
||||
|
||||
@jit
|
||||
def _dsd_kernel(
|
||||
A, B, C,
|
||||
stride_az, stride_ha, stride_am, stride_ak,
|
||||
stride_zb, stride_hb, stride_bk, stride_bn,
|
||||
stride_zc, stride_hc, stride_cm, stride_cn,
|
||||
DS0, DS1, lut,
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
|
||||
):
|
||||
def _dsd_kernel(A, B, C, #
|
||||
stride_az, stride_ha, stride_am, stride_ak, #
|
||||
stride_zb, stride_hb, stride_bk, stride_bn, #
|
||||
stride_zc, stride_hc, stride_cm, stride_cn, #
|
||||
DS0, DS1, lut, #
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, #
|
||||
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr #
|
||||
):
|
||||
# ------------ #
|
||||
# - Prologue - #
|
||||
# ------------ #
|
||||
@@ -229,13 +228,13 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N
|
||||
# compute output
|
||||
grid = lambda meta: [cdiv(BS3, meta['TILE_N']), width, BS0]
|
||||
_dsd_kernel[grid](
|
||||
a, b, c,
|
||||
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
|
||||
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
|
||||
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
|
||||
BS3, AS1, lut,
|
||||
TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
|
||||
num_warps=4, GROUP_SIZE_M=4,
|
||||
a, b, c, #
|
||||
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), #
|
||||
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), #
|
||||
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), #
|
||||
BS3, AS1, lut, #
|
||||
TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4, #
|
||||
num_warps=4, GROUP_SIZE_M=4 #
|
||||
)
|
||||
# exit()
|
||||
return c
|
||||
@@ -337,6 +336,7 @@ def dsd_lut(layout, block, step, trans, device):
|
||||
# create locks
|
||||
return lut, width
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Dense = Dense x Sparse (DDS)
|
||||
# -----------------------------
|
||||
@@ -346,6 +346,7 @@ def dsd_lut(layout, block, step, trans, device):
|
||||
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
|
||||
return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out)
|
||||
|
||||
|
||||
##############
|
||||
# MAIN API #
|
||||
##############
|
||||
@@ -356,10 +357,8 @@ class _matmul(torch.autograd.Function):
|
||||
fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul}
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
|
||||
c_lut, c_width, da_lut, da_width, db_lut, db_width, out
|
||||
):
|
||||
def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_width, da_lut, da_width, db_lut,
|
||||
db_width, out):
|
||||
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out)
|
||||
# save for backward
|
||||
ctx.save_for_backward(a, b)
|
||||
@@ -385,15 +384,13 @@ class _matmul(torch.autograd.Function):
|
||||
# gradients w.r.t. a
|
||||
if ctx.needs_input_grad[0]:
|
||||
mode_da = mode[1] + mode[0] + mode[2]
|
||||
da = _matmul.fn[mode_da](
|
||||
dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width,
|
||||
)
|
||||
da = _matmul.fn[mode_da](dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block,
|
||||
ctx.da_lut, ctx.da_width)
|
||||
# gradients w.r.t. b
|
||||
if ctx.needs_input_grad[1]:
|
||||
mode_db = mode[2] + mode[1] + mode[0]
|
||||
db = _matmul.fn[mode_db](
|
||||
a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width,
|
||||
)
|
||||
db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block,
|
||||
ctx.db_lut, ctx.db_width)
|
||||
dout = dc if ctx.has_out else None
|
||||
return da, db, None, None, None, \
|
||||
None, None, None, None, \
|
||||
@@ -427,11 +424,9 @@ class matmul:
|
||||
self.db_lut, self.db_width = sdd_lut(layout, block, device)
|
||||
|
||||
def __call__(self, a, b, out=None):
|
||||
c = _matmul.apply(
|
||||
a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
|
||||
self.c_lut, self.c_width,
|
||||
self.da_lut, self.da_width,
|
||||
self.db_lut, self.db_width,
|
||||
out
|
||||
)
|
||||
c = _matmul.apply(a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, #
|
||||
self.c_lut, self.c_width, #
|
||||
self.da_lut, self.da_width, #
|
||||
self.db_lut, self.db_width, #
|
||||
out)
|
||||
return c
|
||||
|
||||
@@ -18,14 +18,13 @@ def num_warps(n):
|
||||
|
||||
|
||||
@jit
|
||||
def _blocksparse_softmax_fwd(
|
||||
Out, A, stride_xz, LUT,
|
||||
R, extent, stride_zr, stride_hr, # relative attention
|
||||
scale, is_causal,
|
||||
ROW_SIZE: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
IS_DENSE: tl.constexpr,
|
||||
):
|
||||
def _blocksparse_softmax_fwd(Out, A, stride_xz, LUT, #
|
||||
R, extent, stride_zr, stride_hr, # relative attention
|
||||
scale, is_causal, #
|
||||
ROW_SIZE: tl.constexpr, #
|
||||
BLOCK_SIZE: tl.constexpr, #
|
||||
IS_DENSE: tl.constexpr #
|
||||
):
|
||||
h = tl.program_id(0)
|
||||
m = tl.program_id(1)
|
||||
z = tl.program_id(2)
|
||||
@@ -73,18 +72,16 @@ def _blocksparse_softmax_fwd(
|
||||
|
||||
|
||||
@jit
|
||||
def _blocksparse_softmax_bwd(
|
||||
DA, stride_zdx,
|
||||
DOut, stride_zdout,
|
||||
Out, stride_zout,
|
||||
scale,
|
||||
LUT,
|
||||
DR, extent, stride_zr, stride_hr, stride_er,
|
||||
is_causal,
|
||||
ROW_SIZE: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
IS_DENSE: tl.constexpr,
|
||||
):
|
||||
def _blocksparse_softmax_bwd(DA, stride_zdx, #
|
||||
DOut, stride_zdout, #
|
||||
Out, stride_zout, #
|
||||
scale, #
|
||||
LUT, #
|
||||
DR, extent, stride_zr, stride_hr, stride_er, #
|
||||
is_causal, #
|
||||
ROW_SIZE: tl.constexpr, #
|
||||
BLOCK_SIZE: tl.constexpr, #
|
||||
IS_DENSE: tl.constexpr):
|
||||
h = tl.program_id(0)
|
||||
m = tl.program_id(1)
|
||||
z = tl.program_id(2)
|
||||
@@ -133,6 +130,7 @@ def _blocksparse_softmax_bwd(
|
||||
|
||||
|
||||
class _softmax(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def make_lut(layout, block, device):
|
||||
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
|
||||
@@ -151,10 +149,7 @@ class _softmax(torch.autograd.Function):
|
||||
return lut, int(total_sizes.max())
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx, a, scale, rel_logits, is_causal,
|
||||
spdims, block, lut, maxlut, is_dense
|
||||
):
|
||||
def forward(ctx, a, scale, rel_logits, is_causal, spdims, block, lut, maxlut, is_dense):
|
||||
if scale is not None and isinstance(scale, torch.Tensor):
|
||||
assert scale.device.type == "cpu"
|
||||
scale = scale.item()
|
||||
@@ -165,14 +160,14 @@ class _softmax(torch.autograd.Function):
|
||||
# enqueue kernel
|
||||
out = torch.empty_like(a)
|
||||
_blocksparse_softmax_fwd[grid](
|
||||
out, a, a.stride(0), lut,
|
||||
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
|
||||
scale,
|
||||
is_causal,
|
||||
BLOCK_SIZE=block,
|
||||
ROW_SIZE=next_power_of_2(maxlut),
|
||||
IS_DENSE=is_dense,
|
||||
num_warps=num_warps(maxlut)
|
||||
out, a, a.stride(0), lut, #
|
||||
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn#
|
||||
scale, #
|
||||
is_causal, #
|
||||
BLOCK_SIZE=block, #
|
||||
ROW_SIZE=next_power_of_2(maxlut), #
|
||||
IS_DENSE=is_dense, #
|
||||
num_warps=num_warps(maxlut) #
|
||||
)
|
||||
# save to context
|
||||
# ctx.mark_dirty(x)
|
||||
@@ -201,28 +196,23 @@ class _softmax(torch.autograd.Function):
|
||||
grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)
|
||||
da = torch.empty_like(dout)
|
||||
_blocksparse_softmax_bwd[grid](
|
||||
da, da.stride(0),
|
||||
dout, dout.stride(0),
|
||||
out, out.stride(0),
|
||||
ctx.scale,
|
||||
lut,
|
||||
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],
|
||||
ctx.is_causal,
|
||||
BLOCK_SIZE=ctx.block,
|
||||
ROW_SIZE=next_power_of_2(ctx.maxlut),
|
||||
IS_DENSE=ctx.is_dense,
|
||||
num_warps=num_warps(ctx.maxlut)
|
||||
da, da.stride(0), #
|
||||
dout, dout.stride(0), #
|
||||
out, out.stride(0), #
|
||||
ctx.scale, #
|
||||
lut, #
|
||||
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2], #
|
||||
ctx.is_causal, #
|
||||
BLOCK_SIZE=ctx.block, #
|
||||
ROW_SIZE=next_power_of_2(ctx.maxlut), #
|
||||
IS_DENSE=ctx.is_dense, #
|
||||
num_warps=num_warps(ctx.maxlut) #
|
||||
)
|
||||
return (da, None, None, dr, None,
|
||||
None, None, None, None, None,
|
||||
None,
|
||||
None, None, None,
|
||||
None,
|
||||
None, None, None
|
||||
)
|
||||
return (da, None, None, dr, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
|
||||
|
||||
|
||||
class softmax:
|
||||
|
||||
def __init__(self, layout, block, device, is_dense=False):
|
||||
self.spdims = layout.shape
|
||||
self.layout = layout
|
||||
@@ -233,8 +223,6 @@ class softmax:
|
||||
def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False):
|
||||
if rel_logits is not None and rel_logits.dtype != a.dtype:
|
||||
raise ValueError(f"relative position embedding must be {a.dtype}")
|
||||
a = _softmax.apply(
|
||||
a, scale, rel_logits, is_causal,
|
||||
self.spdims, self.block, self.lut, self.maxlut, self.is_dense,
|
||||
)
|
||||
a = _softmax.apply(a, scale, rel_logits, is_causal, self.spdims, self.block, self.lut, self.maxlut,
|
||||
self.is_dense)
|
||||
return a
|
||||
|
||||
@@ -59,6 +59,7 @@ def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
|
||||
|
||||
|
||||
class _cross_entropy(torch.autograd.Function):
|
||||
|
||||
@classmethod
|
||||
def forward(cls, ctx, logits, indices):
|
||||
# make sure we can use triton
|
||||
|
||||
@@ -15,20 +15,19 @@ from .. import language as tl
|
||||
|
||||
|
||||
@jit
|
||||
def _fwd_kernel(
|
||||
Q, K, V, sm_scale,
|
||||
L,
|
||||
Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
stride_oz, stride_oh, stride_om, stride_on,
|
||||
Z, H, N_CTX,
|
||||
Z_H_N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
):
|
||||
def _fwd_kernel(Q, K, V, sm_scale, #
|
||||
L, #
|
||||
Out, #
|
||||
stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vn, stride_vk, #
|
||||
stride_oz, stride_oh, stride_om, stride_on, #
|
||||
Z, H, N_CTX, #
|
||||
Z_H_N_CTX, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
|
||||
BLOCK_N: tl.constexpr, #
|
||||
IS_CAUSAL: tl.constexpr #
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
qvk_offset = off_hz * stride_qh
|
||||
@@ -40,7 +39,7 @@ def _fwd_kernel(
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, vk_offset),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||
order=(0, 1)
|
||||
order=(0, 1),
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V,
|
||||
@@ -48,7 +47,7 @@ def _fwd_kernel(
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(vk_offset, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
order=(1, 0),
|
||||
)
|
||||
# initialize offsets
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
@@ -104,7 +103,7 @@ def _fwd_kernel(
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(vk_offset + start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
order=(1, 0),
|
||||
)
|
||||
# O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
|
||||
tl.store(O_block_ptr, acc.to(K.dtype.element_ty))
|
||||
@@ -112,9 +111,11 @@ def _fwd_kernel(
|
||||
|
||||
@jit
|
||||
def _bwd_preprocess(
|
||||
Out, DO,
|
||||
Out,
|
||||
DO,
|
||||
Delta,
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
D_HEAD: tl.constexpr,
|
||||
):
|
||||
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, D_HEAD)
|
||||
@@ -128,40 +129,48 @@ def _bwd_preprocess(
|
||||
|
||||
|
||||
@jit
|
||||
def _bwd_kernel_one_col_block(
|
||||
Q, K, V, sm_scale, qk_scale,
|
||||
Out, DO,
|
||||
DQ, DK, DV,
|
||||
L,
|
||||
D,
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
Z, H, N_CTX,
|
||||
off_hz, start_n, num_block,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
SEQUENCE_PARALLEL: tl.constexpr,
|
||||
CAUSAL: tl.constexpr,
|
||||
MMA_V3: tl.constexpr
|
||||
):
|
||||
if SEQUENCE_PARALLEL:
|
||||
DQ += stride_dqa.to(tl.int64) * start_n
|
||||
def _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, #
|
||||
Out, DO, #
|
||||
DQ, DK, DV, #
|
||||
L, #
|
||||
D, #
|
||||
Q_block_ptr, K_block_ptr, V_block_ptr, #
|
||||
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vn, stride_vk, #
|
||||
Z, H, N_CTX, #
|
||||
off_h, off_z, off_hz, start_n, num_block, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
|
||||
BLOCK_N: tl.constexpr, #
|
||||
SEQUENCE_PARALLEL: tl.constexpr, #
|
||||
CAUSAL: tl.constexpr, #
|
||||
MMA_V3: tl.constexpr #
|
||||
):
|
||||
if CAUSAL:
|
||||
lo = start_n * BLOCK_M
|
||||
else:
|
||||
lo = 0
|
||||
|
||||
Q_offset = (off_z * stride_qz + off_h * stride_qh) // stride_qm
|
||||
DQ_offset = off_z * stride_qz + off_h * stride_qh
|
||||
K_offset = (off_z * stride_kz + off_h * stride_kh) // stride_kn
|
||||
V_offset = (off_z * stride_vz + off_h * stride_vh) // stride_vn
|
||||
if SEQUENCE_PARALLEL:
|
||||
DQ_offset += stride_dqa.to(tl.int64) * start_n
|
||||
DQ_offset = DQ_offset // stride_qm
|
||||
|
||||
Q_block_ptr = tl.advance(Q_block_ptr, (lo + Q_offset, 0))
|
||||
K_block_ptr = tl.advance(K_block_ptr, (start_n * BLOCK_M + K_offset, 0))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (start_n * BLOCK_M + V_offset, 0))
|
||||
DO_block_ptr = tl.advance(DO_block_ptr, (lo + Q_offset, 0))
|
||||
DQ_block_ptr = tl.advance(DQ_block_ptr, (lo + DQ_offset, 0))
|
||||
DK_block_ptr = tl.advance(DK_block_ptr, (start_n * BLOCK_M + K_offset, 0))
|
||||
DV_block_ptr = tl.advance(DV_block_ptr, (start_n * BLOCK_M + V_offset, 0))
|
||||
|
||||
# initialize row/col offsets
|
||||
offs_qm = lo + tl.arange(0, BLOCK_M)
|
||||
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_m = tl.arange(0, BLOCK_N)
|
||||
offs_k = tl.arange(0, BLOCK_DMODEL)
|
||||
# initialize pointers to value-like data
|
||||
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
|
||||
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk)
|
||||
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
# pointer to row-wise quantities in value-like data
|
||||
D_ptrs = D + off_hz * N_CTX
|
||||
l_ptrs = L + off_hz * N_CTX
|
||||
@@ -169,17 +178,17 @@ def _bwd_kernel_one_col_block(
|
||||
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# k and v stay in SRAM throughout
|
||||
k = tl.load(k_ptrs)
|
||||
v = tl.load(v_ptrs)
|
||||
k = tl.load(K_block_ptr)
|
||||
v = tl.load(V_block_ptr)
|
||||
# loop over rows
|
||||
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
|
||||
offs_m_curr = start_m + offs_m
|
||||
# load q, k, v, do on-chip
|
||||
q = tl.load(q_ptrs)
|
||||
q = tl.load(Q_block_ptr)
|
||||
# recompute p = softmax(qk, dim=-1).T
|
||||
# NOTE: `do` is pre-divided by `l`; no normalization here
|
||||
if CAUSAL:
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.0), float("-inf"))
|
||||
else:
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, tl.trans(k))
|
||||
@@ -187,7 +196,7 @@ def _bwd_kernel_one_col_block(
|
||||
l_i = tl.load(l_ptrs + offs_m_curr)
|
||||
p = tl.math.exp2(qk - l_i[:, None])
|
||||
# compute dv
|
||||
do = tl.load(do_ptrs)
|
||||
do = tl.load(DO_block_ptr)
|
||||
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do, allow_tf32=True)
|
||||
# compute dp = dot(v, do)
|
||||
Di = tl.load(D_ptrs + offs_m_curr)
|
||||
@@ -199,97 +208,156 @@ def _bwd_kernel_one_col_block(
|
||||
dk += tl.dot(tl.trans(ds), q, allow_tf32=True)
|
||||
# compute dq
|
||||
if not SEQUENCE_PARALLEL:
|
||||
dq = tl.load(dq_ptrs)
|
||||
dq = tl.load(DQ_block_ptr)
|
||||
dq += tl.dot(ds, k, allow_tf32=True)
|
||||
tl.store(dq_ptrs, dq)
|
||||
tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty))
|
||||
elif SEQUENCE_PARALLEL:
|
||||
if MMA_V3:
|
||||
dq = tl.dot(ds, k, allow_tf32=True)
|
||||
else:
|
||||
# does not work with mma v3, because M % 64 != 0
|
||||
dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds), allow_tf32=True))
|
||||
tl.store(dq_ptrs, dq)
|
||||
tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty))
|
||||
|
||||
# increment pointers
|
||||
dq_ptrs += BLOCK_M * stride_qm
|
||||
q_ptrs += BLOCK_M * stride_qm
|
||||
do_ptrs += BLOCK_M * stride_qm
|
||||
DQ_block_ptr = tl.advance(DQ_block_ptr, (BLOCK_M, 0))
|
||||
Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0))
|
||||
DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0))
|
||||
# write-back
|
||||
dv_ptrs = DV + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk)
|
||||
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
|
||||
tl.store(dv_ptrs, dv)
|
||||
tl.store(dk_ptrs, dk)
|
||||
tl.store(DV_block_ptr, dv.to(V.dtype.element_ty))
|
||||
tl.store(DK_block_ptr, dk.to(K.dtype.element_ty))
|
||||
|
||||
|
||||
@jit
|
||||
def _bwd_kernel(
|
||||
# fmt: off
|
||||
Q, K, V, sm_scale,
|
||||
Out, DO,
|
||||
DQ, DK, DV,
|
||||
L,
|
||||
D,
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
Z, H, N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
SEQUENCE_PARALLEL: tl.constexpr,
|
||||
CAUSAL: tl.constexpr,
|
||||
MMA_V3: tl.constexpr
|
||||
# fmt: on
|
||||
):
|
||||
def _bwd_kernel(Q, K, V, sm_scale, #
|
||||
Out, DO, #
|
||||
DQ, DK, DV, #
|
||||
L, #
|
||||
D, #
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vn, stride_vk, #
|
||||
Z, H, N_CTX, #
|
||||
Z_H_N_CTX, #
|
||||
SQ_Z_H_N_CTX, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
|
||||
BLOCK_N: tl.constexpr, #
|
||||
SEQUENCE_PARALLEL: tl.constexpr, #
|
||||
CAUSAL: tl.constexpr, #
|
||||
MMA_V3: tl.constexpr #
|
||||
):
|
||||
qk_scale = sm_scale * 1.44269504
|
||||
off_hz = tl.program_id(0)
|
||||
off_z = off_hz // H
|
||||
off_h = off_hz % H
|
||||
# offset pointers for batch/head
|
||||
Q += off_z * stride_qz + off_h * stride_qh
|
||||
K += off_z * stride_kz + off_h * stride_kh
|
||||
V += off_z * stride_vz + off_h * stride_vh
|
||||
DO += off_z * stride_qz + off_h * stride_qh
|
||||
DQ += off_z * stride_qz + off_h * stride_qh
|
||||
DK += off_z * stride_kz + off_h * stride_kh
|
||||
DV += off_z * stride_vz + off_h * stride_vh
|
||||
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_kn, stride_kk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
DO_block_ptr = tl.make_block_ptr(
|
||||
base=DO,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
if SEQUENCE_PARALLEL:
|
||||
DQ_block_ptr = tl.make_block_ptr(
|
||||
base=DQ,
|
||||
shape=(SQ_Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
else:
|
||||
DQ_block_ptr = tl.make_block_ptr(
|
||||
base=DQ,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
|
||||
DK_block_ptr = tl.make_block_ptr(
|
||||
base=DK,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_kn, stride_kk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
DV_block_ptr = tl.make_block_ptr(
|
||||
base=DV,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
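The hunk above switches the backward kernel from raw pointer arithmetic to tl.make_block_ptr handles for Q, K, V, DO, DQ, DK and DV. As a reference for readers unfamiliar with the API, here is a minimal, self-contained sketch of the block-pointer pattern; the kernel name, shapes and block sizes are hypothetical and not part of this diff.

import torch
import triton
import triton.language as tl


@triton.jit
def copy_rows(X, Y, M, N, stride_m, stride_n,  #
              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid = tl.program_id(0)
    # Build block pointers over the 2-D views of X and Y.
    x_ptr = tl.make_block_ptr(base=X, shape=(M, N), strides=(stride_m, stride_n),
                              offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    y_ptr = tl.make_block_ptr(base=Y, shape=(M, N), strides=(stride_m, stride_n),
                              offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    for _ in range(0, tl.cdiv(N, BLOCK_N)):
        tile = tl.load(x_ptr, boundary_check=(0, 1))
        tl.store(y_ptr, tile, boundary_check=(0, 1))
        # Advance both block pointers one tile to the right.
        x_ptr = tl.advance(x_ptr, (0, BLOCK_N))
        y_ptr = tl.advance(y_ptr, (0, BLOCK_N))


x = torch.randn(128, 256, device="cuda")
y = torch.empty_like(x)
copy_rows[(triton.cdiv(128, 32), )](x, y, 128, 256, x.stride(0), x.stride(1), BLOCK_M=32, BLOCK_N=64)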
num_block_n = tl.cdiv(N_CTX, BLOCK_N)
|
||||
if not SEQUENCE_PARALLEL:
|
||||
for start_n in range(0, num_block_n):
|
||||
_bwd_kernel_one_col_block(
|
||||
Q, K, V, sm_scale, qk_scale, Out, DO,
|
||||
DQ, DK, DV,
|
||||
L,
|
||||
D,
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
Z, H, N_CTX,
|
||||
off_hz, start_n, num_block_n,
|
||||
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_N=BLOCK_N,
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
|
||||
CAUSAL=CAUSAL,
|
||||
MMA_V3=MMA_V3
|
||||
)
|
||||
_bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, #
|
||||
DQ, DK, DV, #
|
||||
L, #
|
||||
D, #
|
||||
Q_block_ptr, K_block_ptr, V_block_ptr, #
|
||||
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vn, stride_vk, #
|
||||
Z, H, N_CTX, #
|
||||
off_h, off_z, off_hz, start_n, num_block_n, #
|
||||
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, #
|
||||
BLOCK_N=BLOCK_N, #
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, #
|
||||
CAUSAL=CAUSAL, #
|
||||
MMA_V3=MMA_V3 #
|
||||
)
|
||||
else:
|
||||
start_n = tl.program_id(1)
|
||||
_bwd_kernel_one_col_block(
|
||||
Q, K, V, sm_scale, qk_scale, Out, DO,
|
||||
DQ, DK, DV,
|
||||
L,
|
||||
D,
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
Z, H, N_CTX,
|
||||
off_hz, start_n, num_block_n,
|
||||
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_N=BLOCK_N,
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
|
||||
CAUSAL=CAUSAL,
|
||||
MMA_V3=MMA_V3
|
||||
)
|
||||
_bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, #
|
||||
DQ, DK, DV, #
|
||||
L, #
|
||||
D, #
|
||||
Q_block_ptr, K_block_ptr, V_block_ptr, #
|
||||
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vn, stride_vk, #
|
||||
Z, H, N_CTX, #
|
||||
off_h, off_z, off_hz, start_n, num_block_n, #
|
||||
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, #
|
||||
BLOCK_N=BLOCK_N, #
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, #
|
||||
CAUSAL=CAUSAL, #
|
||||
MMA_V3=MMA_V3 #
|
||||
)
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
@@ -315,19 +383,20 @@ class _attention(torch.autograd.Function):
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
_fwd_kernel[grid](
|
||||
q, k, v, sm_scale,
|
||||
L,
|
||||
o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
q.shape[0] * q.shape[1] * q.shape[2],
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
|
||||
IS_CAUSAL=causal,
|
||||
num_warps=num_warps,
|
||||
num_stages=4)
|
||||
q, k, v, sm_scale, #
|
||||
L, #
|
||||
o, #
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
|
||||
q.shape[0], q.shape[1], q.shape[2], #
|
||||
q.shape[0] * q.shape[1] * q.shape[2], #
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, #
|
||||
IS_CAUSAL=causal, #
|
||||
num_warps=num_warps, #
|
||||
num_stages=4 #
|
||||
)
|
||||
|
||||
ctx.save_for_backward(q, k, v, o, L)
|
||||
ctx.grid = grid
|
||||
@@ -348,35 +417,39 @@ class _attention(torch.autograd.Function):
|
||||
do = do.contiguous()
if sequence_parallel:
replicas = cdiv(seq_len_kv, BLOCK)
new_dq_shape = (replicas,) + q.shape
new_dq_shape = (replicas, ) + q.shape
dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype)
else:
dq = torch.zeros_like(q, dtype=torch.float32)
dq = torch.zeros_like(q, dtype=q.dtype)
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
delta = torch.empty_like(L)
|
||||
_bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1], )](
|
||||
o, do,
|
||||
o,
|
||||
do,
|
||||
delta,
|
||||
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
BLOCK_M=BLOCK,
|
||||
D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
_bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do,
|
||||
dq, dk, dv,
|
||||
L,
|
||||
delta,
|
||||
o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
|
||||
SEQUENCE_PARALLEL=sequence_parallel,
|
||||
CAUSAL=ctx.causal,
|
||||
MMA_V3=MMA_V3,
|
||||
num_warps=8,
|
||||
num_stages=1,
|
||||
q, k, v, ctx.sm_scale, #
|
||||
o, do, #
|
||||
dq, dk, dv, #
|
||||
L, #
|
||||
delta, #
|
||||
o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
|
||||
q.shape[0], q.shape[1], q.shape[2], #
|
||||
q.shape[0] * q.shape[1] * q.shape[2], #
|
||||
cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2], #
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK, #
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, #
|
||||
SEQUENCE_PARALLEL=sequence_parallel, #
|
||||
CAUSAL=ctx.causal, #
|
||||
MMA_V3=MMA_V3, #
|
||||
num_warps=8, #
|
||||
num_stages=1 #
|
||||
)
if len(dq.shape) == 5:
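The `if len(dq.shape) == 5` check above presumably guards the reduction of the replicated dQ buffer allocated in the sequence-parallel branch. A hedged, standalone illustration of that accumulation idea (all shapes and the replica count are made up):

import torch

replicas = 4  # hypothetical number of key/value column blocks
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
dq = torch.zeros((replicas, ) + q.shape, device=q.device, dtype=q.dtype)
# ... each sequence-parallel program writes its partial dQ into dq[i] ...
dq = dq.sum(dim=0)  # collapse the replica dimension back to q.shape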
@@ -37,8 +37,9 @@ def get_configs_io_bound():
|
||||
num_stages=num_stages, num_warps=num_warps))
|
||||
# split_k
for split_k in [2, 4, 8, 16]:
configs.append(Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
configs.append(
Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs
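Split-K configs launch several programs that accumulate partial products into the same output, so the output tensor has to start from zero on every (re)run; that is the purpose of the pre_hook seen above. A minimal, hypothetical sketch of such a hook (a stand-in for the real init_to_zero helper, values illustrative):

from triton import Config

def init_to_zero(arg_name):
    # Zero the named argument before the kernel (re)launches.
    return lambda nargs: nargs[arg_name].zero_()

split_k_config = Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 4},
                        num_stages=2, num_warps=4, pre_hook=init_to_zero('C'))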
|
||||
|
||||
|
||||
@@ -69,22 +70,22 @@ def get_configs_io_bound():
|
||||
prune_configs_by={
|
||||
'early_config_prune': early_config_prune,
|
||||
'perf_model': estimate_matmul_time,
|
||||
'top_k': 10
|
||||
'top_k': 10,
|
||||
},
|
||||
)
|
||||
@heuristics({
|
||||
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||
})
|
||||
@jit
|
||||
def _kernel(A, B, C, M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
dot_out_dtype: tl.constexpr,
|
||||
allow_tf32: tl.constexpr,
|
||||
fp8_fast_accum: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr
|
||||
def _kernel(A, B, C, M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
dot_out_dtype: tl.constexpr, #
|
||||
allow_tf32: tl.constexpr, #
|
||||
fp8_fast_accum: tl.constexpr, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr #
|
||||
):
|
||||
# matrix multiplication
|
||||
pid = tl.program_id(0)
|
||||
@@ -184,14 +185,15 @@ class _matmul(torch.autograd.Function):
|
||||
ab_dtype = False
|
||||
# launch kernel
|
||||
grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||
_kernel[grid](a, b, c, M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
dot_out_dtype=dot_out_dtype,
|
||||
allow_tf32=allow_tf32,
|
||||
fp8_fast_accum=fp8_fast_accum,
|
||||
GROUP_M=8, AB_DTYPE=ab_dtype)
|
||||
_kernel[grid](
|
||||
a, b, c, M, N, K, #
|
||||
a.stride(0), a.stride(1), #
|
||||
b.stride(0), b.stride(1), #
|
||||
c.stride(0), c.stride(1), #
|
||||
dot_out_dtype=dot_out_dtype, #
|
||||
allow_tf32=allow_tf32, #
|
||||
fp8_fast_accum=fp8_fast_accum, #
|
||||
GROUP_M=8, AB_DTYPE=ab_dtype)
|
||||
return c
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -5,8 +5,7 @@ import torch
|
||||
from .. import cdiv
|
||||
from .._C.libtriton.triton import runtime
|
||||
from ..runtime import driver
|
||||
from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops,
|
||||
nvsmi)
|
||||
from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops, nvsmi)
|
||||
|
||||
|
||||
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
@@ -14,7 +13,8 @@ def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
total_warps = num_ctas * min(num_warps, 4)
|
||||
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
|
||||
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
|
||||
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, cur_sm_clock, backend, device)
|
||||
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(
|
||||
dtype, cur_sm_clock, backend, device)
|
||||
return tflops
|
||||
|
||||
|
||||
@@ -35,12 +35,12 @@ def get_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
|
||||
|
||||
def estimate_matmul_time(
# backend, device,
num_warps, num_stages,
A, B, C,
M, N, K,
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
debug=False, **kwargs
# backend, device,
num_warps, num_stages, #
A, B, C, #
M, N, K, #
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, #
debug=False, **kwargs #
):
''' return estimated running time in ms
= max(compute, loading) + store '''
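The docstring above summarises the cost model: a launch is treated as either compute-bound or memory-bound, so the estimate is the maximum of the compute time and the load time, plus the time to write the result back. A toy, hedged rendering of that formula (all numbers are illustrative, not measured):

def estimate_time_ms(flops, tflops, bytes_loaded, bytes_stored, dram_gbps):
    compute_ms = flops / (tflops * 1e12) * 1e3
    load_ms = bytes_loaded / (dram_gbps * 1e9) * 1e3
    store_ms = bytes_stored / (dram_gbps * 1e9) * 1e3
    return max(compute_ms, load_ms) + store_ms

# Hypothetical 4096^3 fp16 GEMM on a card with 150 TFLOPS and 1.5 TB/s of DRAM bandwidth.
print(estimate_time_ms(flops=2 * 4096**3, tflops=150,
                       bytes_loaded=2 * 4096**2 * 2, bytes_stored=4096**2 * 2, dram_gbps=1500))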
|
||||
@@ -149,8 +149,9 @@ def early_config_prune(configs, named_args):
|
||||
optimal_num_stages = ldgsts_latency / mma_cycles
|
||||
|
||||
# nearest stages, prefer large #stages
nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
nearest = heapq.nsmallest(
2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)

for n in nearest:
pruned_configs.append(n[0])
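The asymmetric key above implements the "nearest stages, prefer large" rule: candidates whose num_stages fall below the optimum get a constant penalty of 10, so ties resolve toward configs with more stages. A small, self-contained demonstration (config names are placeholders):

import heapq

optimal_num_stages = 3.4
v = [("cfg_a", 2), ("cfg_b", 3), ("cfg_c", 4), ("cfg_d", 5)]  # (config, num_stages) pairs

nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
                          if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
print(nearest)  # [('cfg_c', 4), ('cfg_d', 5)] -- the two stage counts above the optimum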
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune,
|
||||
heuristics)
|
||||
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune, heuristics)
|
||||
from .driver import driver
|
||||
from .jit import (JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret,
|
||||
version_key)
|
||||
from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret
|
||||
|
||||
__all__ = [
|
||||
"driver",
|
||||
@@ -12,7 +10,6 @@ __all__ = [
|
||||
"heuristics",
|
||||
"JITFunction",
|
||||
"KernelInterface",
|
||||
"version_key",
|
||||
"reinterpret",
|
||||
"TensorWrapper",
|
||||
"OutOfResources",
|
||||
|
||||
@@ -9,11 +9,10 @@ from .jit import KernelInterface
|
||||
|
||||
|
||||
class OutOfResources(Exception):
|
||||
|
||||
def __init__(self, required, limit, name):
|
||||
self.message = f'out of resource: {name}, '\
|
||||
f'Required: {required}, '\
|
||||
f'Hardware limit: {limit}'
|
||||
self.message += '. Reducing block sizes or `num_stages` may help.'
|
||||
self.message = (f"out of resource: {name}, Required: {required}, Hardware limit: {limit}. " +
|
||||
"Reducing block sizes or `num_stages` may help.")
|
||||
self.required = required
|
||||
self.limit = limit
|
||||
self.name = name
|
||||
@@ -25,38 +24,77 @@ class OutOfResources(Exception):
|
||||
|
||||
|
||||
class Autotuner(KernelInterface):
<<<<<<< HEAD
def __init__(self, fn, arg_names, configs, key, verbose, reset_to_zero, prune_configs_by: Dict = None, warmup=25, rep=100):
'''
=======

def __init__(
self,
fn,
arg_names,
configs,
key,
reset_to_zero,
restore_value,
prune_configs_by: Dict = None,
warmup=25,
rep=100,
):
"""
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predicate running time with different configs, returns running time
|
||||
'top_k': number of configs to bench
|
||||
'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
|
||||
'''
|
||||
"""
|
||||
if not configs:
|
||||
self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)]
|
||||
else:
|
||||
self.configs = configs
|
||||
self.key_idx = [arg_names.index(k) for k in key]
|
||||
self.cache = {}
|
||||
# hook to reset all required tensor to zeros before relaunching a kernel
|
||||
self.hook = lambda args: 0
|
||||
self.arg_names = arg_names
|
||||
|
||||
# Reset to zero or restore values
self.reset_idx = []
if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
self.restore_idx = []
if restore_value is not None:
self.restore_idx = [arg_names.index(k) for k in restore_value]

def _hook(args):
# Hook to reset or restore for required tensors
self.pre_hook = lambda args, reset_only=False: 0
self.post_hook = lambda args: 0
if len(self.reset_idx) > 0 or len(self.restore_idx) > 0:

def _pre_hook(args, reset_only=False):
for i in self.reset_idx:
args[i].zero_()
self.hook = _hook
self.arg_names = arg_names
# prune configs
if not reset_only:
self.restore_copies = [args[i].clone() for i in self.restore_idx]

self.pre_hook = _pre_hook
if len(self.restore_idx) > 0:

def _post_hook(args):
for i, j in enumerate(self.restore_idx):
args[j].copy_(self.restore_copies[i])
self.restore_copies = []

self.post_hook = _post_hook
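Taken together, reset_to_zero re-zeroes the named arguments before every timed config and restore_value snapshots and restores arguments that the kernel mutates in place, so benchmarking one config cannot contaminate the next. A hedged usage sketch (kernel, argument names and configs are hypothetical, not from this diff):

import triton
import triton.language as tl


@triton.autotune(
    configs=[triton.Config({'BLOCK': 128}), triton.Config({'BLOCK': 256})],
    key=['n_elements'],
    reset_to_zero=['Out'],    # accumulator must start at zero for every config
    restore_value=['X'],      # X is updated in place, so restore it between configs
)
@triton.jit
def accumulate(X, Out, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(X + offs, mask=mask, other=0.0)
    tl.atomic_add(Out, tl.sum(x, axis=0))
    tl.store(X + offs, x + 1, mask=mask)

# x = torch.randn(4096, device="cuda"); out = torch.zeros(1, device="cuda")
# accumulate[lambda META: (triton.cdiv(4096, META['BLOCK']), )](x, out, 4096)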
|
||||
|
||||
# Prune configs
|
||||
if prune_configs_by:
|
||||
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
|
||||
if 'early_config_prune' in prune_configs_by:
|
||||
early_config_prune = prune_configs_by['early_config_prune']
|
||||
perf_model, top_k = prune_configs_by["perf_model"], prune_configs_by["top_k"]
|
||||
if "early_config_prune" in prune_configs_by:
|
||||
early_config_prune = prune_configs_by["early_config_prune"]
|
||||
else:
|
||||
perf_model, top_k, early_config_prune = None, None, None
|
||||
self.perf_model, self.configs_top_k = perf_model, top_k
|
||||
self.early_config_prune = early_config_prune
|
||||
|
||||
self.fn = fn
|
||||
self.warmup = warmup
|
||||
self.rep = rep
|
||||
@@ -67,10 +105,8 @@ class Autotuner(KernelInterface):
|
||||
# as kwargs and by the autotuner
|
||||
conflicts = meta.keys() & config.kwargs.keys()
|
||||
if conflicts:
|
||||
raise ValueError(
|
||||
f"Conflicting meta-parameters: {', '.join(conflicts)}."
|
||||
" Make sure that you don't re-define auto-tuned symbols."
|
||||
)
|
||||
raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}."
|
||||
" Make sure that you don't re-define auto-tuned symbols.")
|
||||
# augment meta-parameters with tunable ones
|
||||
current = dict(meta, **config.kwargs)
|
||||
full_nargs = {**self.nargs, **current}
|
||||
@@ -78,16 +114,22 @@ class Autotuner(KernelInterface):
|
||||
def kernel_call():
|
||||
if config.pre_hook:
|
||||
config.pre_hook(full_nargs)
|
||||
self.hook(args)
|
||||
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
# enable_persistent=False,
|
||||
**current)
|
||||
self.pre_hook(args)
|
||||
self.fn.run(
|
||||
*args,
|
||||
num_warps=config.num_warps,
|
||||
num_stages=config.num_stages,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
# enable_persistent=False,
|
||||
**current,
|
||||
)
|
||||
self.post_hook(args)
|
||||
|
||||
try:
return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
except OutOfResources:
return [float('inf'), float('inf'), float('inf')]
return [float("inf"), float("inf"), float("inf")]
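Returning +inf for a config that raises OutOfResources keeps oversized configs in the timing table without ever letting them win. A small illustration of why that works with the min-over-timings selection used later in run():

import builtins

timings = {"cfg_small": [0.21, 0.20, 0.23],      # hypothetical (median, p20, p80) timings
           "cfg_too_big": [float("inf")] * 3}    # config that exhausted shared memory
best = builtins.min(timings, key=timings.get)
print(best)  # cfg_small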
|
||||
|
||||
def get_best_config(self):
|
||||
return self.best_config
|
||||
@@ -110,12 +152,11 @@ class Autotuner(KernelInterface):
|
||||
# prune configs
|
||||
pruned_configs = self.prune_configs(kwargs)
|
||||
bench_start = time.time()
|
||||
timings = {config: self._bench(*args, config=config, **kwargs)
|
||||
for config in pruned_configs}
|
||||
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
|
||||
bench_end = time.time()
|
||||
self.bench_time = bench_end - bench_start
|
||||
self.cache[key] = builtins.min(timings, key=timings.get)
|
||||
self.hook(args)
|
||||
self.pre_hook(args, reset_only=True)
|
||||
self.configs_timings = timings
|
||||
if self.verbose:
|
||||
print(str(key) + ": " + str(self.cache[key]))
|
||||
@@ -126,9 +167,15 @@ class Autotuner(KernelInterface):
|
||||
full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs}
|
||||
if config.pre_hook is not None:
|
||||
config.pre_hook(full_nargs)
|
||||
ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization, **kwargs, **config.kwargs)
|
||||
ret = self.fn.run(
|
||||
*args,
|
||||
num_warps=config.num_warps,
|
||||
num_stages=config.num_stages,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
**kwargs,
|
||||
**config.kwargs,
|
||||
)
|
||||
self.nargs = None
|
||||
return ret
|
||||
|
||||
@@ -142,17 +189,20 @@ class Autotuner(KernelInterface):
|
||||
top_k = int(len(self.configs) * top_k)
|
||||
if len(pruned_configs) > top_k:
|
||||
est_timing = {
|
||||
config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
|
||||
num_warps=config.num_warps,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
enable_persistent=config.enable_persistent)
|
||||
config:
|
||||
self.perf_model(
|
||||
**self.nargs,
|
||||
**kwargs,
|
||||
**config.kwargs,
|
||||
num_stages=config.num_stages,
|
||||
num_warps=config.num_warps,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
enable_persistent=config.enable_persistent,
|
||||
)
|
||||
for config in pruned_configs
}
pruned_configs = sorted(
est_timing.keys(),
key=lambda x: est_timing[x])[
:top_k]
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
return pruned_configs
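So prune_configs_by first applies an optional early prune, then scores every surviving config with the perf model and keeps only the top_k cheapest for real benchmarking. A hedged sketch of wiring that up; the model below is a stand-in, not the real estimate_matmul_time:

import triton

def toy_perf_model(**kwargs):
    # Pretend larger tiles are cheaper, with a small penalty per pipeline stage.
    return 1.0 / (kwargs['BLOCK_M'] * kwargs['BLOCK_N']) + 0.01 * kwargs['num_stages']

autotune_kwargs = dict(
    configs=[triton.Config({'BLOCK_M': m, 'BLOCK_N': n}) for m in (64, 128) for n in (64, 128)],
    key=['M', 'N', 'K'],
    prune_configs_by={'perf_model': toy_perf_model, 'top_k': 2},
)
print(toy_perf_model(BLOCK_M=128, BLOCK_N=128, num_stages=3))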
|
||||
|
||||
def warmup(self, *args, **kwargs):
|
||||
@@ -195,13 +245,14 @@ class Config:
|
||||
self.num_ctas = num_ctas
|
||||
self.num_stages = num_stages
|
||||
self.enable_warp_specialization = enable_warp_specialization
|
||||
# TODO[shuhaoj]: May make enable_persistent configurable in future if necessay.
|
||||
# TODO[shuhaoj]: May make enable_persistent configurable in future if necessary.
|
||||
self.enable_persistent = False
|
||||
self.pre_hook = pre_hook
|
||||
|
||||
def __str__(self):
|
||||
res = []
|
||||
for k, v in self.kwargs.items():
|
||||
<<<<<<< HEAD
|
||||
res.append(f'{k}: {v}')
|
||||
res.append(f'num_warps: {self.num_warps}')
|
||||
## Comment out Hopper specific parameters
|
||||
@@ -214,6 +265,18 @@ class Config:
|
||||
|
||||
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=False, warmup=25, rep=100):
|
||||
=======
|
||||
res.append(f"{k}: {v}")
|
||||
res.append(f"num_warps: {self.num_warps}")
|
||||
res.append(f"num_ctas: {self.num_ctas}")
|
||||
res.append(f"num_stages: {self.num_stages}")
|
||||
res.append(f"enable_warp_specialization: {self.enable_warp_specialization}")
|
||||
res.append(f"enable_persistent: {self.enable_persistent}")
|
||||
return ", ".join(res)
|
||||
|
||||
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, warmup=25, rep=100):
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
"""
|
||||
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
||||
|
||||
@@ -244,6 +307,8 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=Fa
|
||||
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
|
||||
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
|
||||
:type reset_to_zero: list[str]
|
||||
:param restore_value: a list of argument names whose value will be restored after evaluating any configs.
|
||||
:type restore_value: list[str]
|
||||
:param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25.
|
||||
:type warmup: int
|
||||
:param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
|
||||
@@ -251,8 +316,13 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=Fa
|
||||
:param verbose: a boolean that controls whether the best_config for each key is printed
|
||||
:type verbose: bool
|
||||
"""
|
||||
|
||||
def decorator(fn):
|
||||
<<<<<<< HEAD
|
||||
return Autotuner(fn, fn.arg_names, configs, key, verbose, reset_to_zero, prune_configs_by, warmup, rep)
|
||||
=======
|
||||
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, prune_configs_by, warmup, rep)
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -286,6 +356,7 @@ def heuristics(values):
|
||||
each such function takes a list of positional arguments as input.
|
||||
:type values: dict[str, Callable[[list[Any]], Any]]
|
||||
"""
|
||||
|
||||
def decorator(fn):
|
||||
return Heuristics(fn, fn.arg_names, values)
|
||||
|
||||
|
||||
@@ -1,27 +1,42 @@
|
||||
#include "cuda.h"
|
||||
#include <dlfcn.h>
|
||||
#include <stdbool.h>
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
|
||||
static inline void gpuAssert(CUresult code, const char *file, int line) {
|
||||
if (code != CUDA_SUCCESS) {
|
||||
const char *prefix = "Triton Error [CUDA]: ";
|
||||
const char *str;
|
||||
cuGetErrorString(code, &str);
|
||||
char err[1024] = {0};
|
||||
strcat(err, prefix);
|
||||
strcat(err, str);
|
||||
PyGILState_STATE gil_state;
|
||||
gil_state = PyGILState_Ensure();
|
||||
PyErr_SetString(PyExc_RuntimeError, err);
|
||||
PyGILState_Release(gil_state);
|
||||
}
|
||||
// Raises a Python exception and returns false if code is not CUDA_SUCCESS.
|
||||
static bool gpuAssert(CUresult code, const char *file, int line) {
|
||||
if (code == CUDA_SUCCESS)
|
||||
return true;
|
||||
|
||||
const char *prefix = "Triton Error [CUDA]: ";
|
||||
const char *str;
|
||||
cuGetErrorString(code, &str);
|
||||
char err[1024] = {0};
|
||||
strcat(err, prefix);
|
||||
strcat(err, str);
|
||||
PyGILState_STATE gil_state;
|
||||
gil_state = PyGILState_Ensure();
|
||||
PyErr_SetString(PyExc_RuntimeError, err);
|
||||
PyGILState_Release(gil_state);
|
||||
return false;
|
||||
}
|
||||
|
||||
#define CUDA_CHECK(ans) \
|
||||
{ \
|
||||
{ gpuAssert((ans), __FILE__, __LINE__); } \
|
||||
}
|
||||
// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block.
|
||||
#define CUDA_CHECK_AND_RETURN_NULL(ans) \
|
||||
do { \
|
||||
if (!gpuAssert((ans), __FILE__, __LINE__)) \
|
||||
return NULL; \
|
||||
} while (0)
|
||||
|
||||
// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block.
|
||||
#define CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \
|
||||
do { \
|
||||
if (!gpuAssert((ans), __FILE__, __LINE__)) { \
|
||||
PyEval_RestoreThread(_save); \
|
||||
return NULL; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define ADD_ENUM_ITEM(value) \
|
||||
do { \
|
||||
@@ -200,16 +215,16 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
|
||||
int sm_clock_rate;
|
||||
int mem_clock_rate;
|
||||
int mem_bus_width;
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
|
||||
&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
|
||||
device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
|
||||
&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate,
|
||||
CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
|
||||
&sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
|
||||
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
|
||||
&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
|
||||
&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
|
||||
|
||||
return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
|
||||
@@ -237,33 +252,37 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
|
||||
CUcontext pctx = 0;
|
||||
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
CUDA_CHECK(cuCtxGetCurrent(&pctx));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&pctx));
|
||||
if (!pctx) {
|
||||
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
|
||||
CUDA_CHECK(cuCtxSetCurrent(pctx));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
||||
cuDevicePrimaryCtxRetain(&pctx, device));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(pctx));
|
||||
}
|
||||
|
||||
CUDA_CHECK(cuModuleLoadData(&mod, data));
|
||||
CUDA_CHECK(cuModuleGetFunction(&fun, mod, name));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuModuleLoadData(&mod, data));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
||||
cuModuleGetFunction(&fun, mod, name));
|
||||
// get allocated registers and spilled registers from the function
|
||||
CUDA_CHECK(cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
|
||||
CUDA_CHECK(
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
||||
cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
||||
cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
|
||||
n_spills /= 4;
|
||||
// set dynamic shared memory if necessary
|
||||
int shared_optin;
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
|
||||
&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
|
||||
device));
|
||||
if (shared > 49152 && shared_optin > 49152) {
|
||||
CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
||||
cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
|
||||
int shared_total, shared_static;
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
|
||||
&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
|
||||
device));
|
||||
CUDA_CHECK(cuFuncGetAttribute(&shared_static,
|
||||
CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
|
||||
CUDA_CHECK(
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute(
|
||||
&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
||||
cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
||||
shared_optin - shared_static));
|
||||
}
|
||||
@@ -286,7 +305,7 @@ static PyObject *memAlloc(PyObject *self, PyObject *args) {
|
||||
}
|
||||
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
CUDA_CHECK(cuMemAlloc(&dptr, bytesize));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuMemAlloc(&dptr, bytesize));
|
||||
Py_END_ALLOW_THREADS;
|
||||
|
||||
return PyLong_FromUnsignedLongLong((unsigned long long)dptr);
|
||||
@@ -307,7 +326,8 @@ static PyObject *memcpyHtoD(PyObject *self, PyObject *args) {
|
||||
srcHost = (const void *)srcHostPtr;
|
||||
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
CUDA_CHECK(cuMemcpyHtoD(dstDevice, srcHost, byteCount));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
||||
cuMemcpyHtoD(dstDevice, srcHost, byteCount));
|
||||
Py_END_ALLOW_THREADS;
|
||||
|
||||
Py_RETURN_NONE;
|
||||
@@ -321,7 +341,7 @@ static PyObject *memFree(PyObject *self, PyObject *args) {
|
||||
}
|
||||
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
CUDA_CHECK(cuMemFree(dptr));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuMemFree(dptr));
|
||||
Py_END_ALLOW_THREADS;
|
||||
|
||||
Py_RETURN_NONE;
|
||||
@@ -411,7 +431,7 @@ static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) {
|
||||
}
|
||||
// Call the function
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
CUDA_CHECK(cuTensorMapEncodeTiledHandle(
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuTensorMapEncodeTiledHandle(
|
||||
tensorMap, tensorDataType, tensorRank, globalAddress, globalDim,
|
||||
globalStrides, boxDim, elementStrides, interleave, swizzle, l2Promotion,
|
||||
oobFill));
|
||||
|
||||
@@ -19,6 +19,7 @@ def default_dump_dir():
|
||||
|
||||
|
||||
class CacheManager(ABC):
|
||||
|
||||
def __init__(self, key):
|
||||
pass
|
||||
|
||||
@@ -44,20 +45,21 @@ class CacheManager(ABC):
|
||||
|
||||
|
||||
class FileCacheManager(CacheManager):
|
||||
|
||||
def __init__(self, key, override=False, dump=False):
|
||||
self.key = key
|
||||
self.lock_path = None
|
||||
if (dump):
|
||||
if dump:
|
||||
self.cache_dir = default_dump_dir()
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||
os.makedirs(self.cache_dir, exist_ok=True)
|
||||
elif (override):
|
||||
elif override:
|
||||
self.cache_dir = default_override_dir()
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
else:
|
||||
# create cache directory if it doesn't exist
|
||||
self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir()
|
||||
self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
|
||||
if self.cache_dir:
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||
@@ -93,9 +95,8 @@ class FileCacheManager(CacheManager):
|
||||
result = {}
|
||||
for c in child_paths:
|
||||
p = self._make_path(c)
|
||||
if not os.path.exists(p):
|
||||
raise Exception(f"Group file {p} does not exist from group {grp_filename} ")
|
||||
result[c] = p
|
||||
if os.path.exists(p):
|
||||
result[c] = p
|
||||
return result
|
||||
|
||||
# Note a group of pushed files as being part of a group
|
||||
@@ -142,6 +143,7 @@ def get_cache_manager(key) -> CacheManager:
|
||||
|
||||
if user_cache_manager is not None and user_cache_manager != __cache_cls_nme:
|
||||
import importlib
|
||||
|
||||
module_path, clz_nme = user_cache_manager.split(":")
|
||||
module = importlib.import_module(module_path)
|
||||
__cache_cls = getattr(module, clz_nme)
|
||||
|
||||
@@ -9,7 +9,6 @@ from .cache import get_cache_manager
|
||||
|
||||
|
||||
class DriverBase(metaclass=abc.ABCMeta):
|
||||
|
||||
CUDA = 0
|
||||
HIP = 1
|
||||
|
||||
@@ -19,6 +18,8 @@ class DriverBase(metaclass=abc.ABCMeta):
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# CUDA
|
||||
# -----------------------------
|
||||
@@ -27,7 +28,7 @@ class DriverBase(metaclass=abc.ABCMeta):
|
||||
class CudaUtils(object):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(CudaUtils, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
@@ -47,6 +48,7 @@ class CudaUtils(object):
|
||||
with open(so, "rb") as f:
|
||||
cache_path = cache.put(f.read(), fname, binary=True)
|
||||
import importlib.util
|
||||
|
||||
spec = importlib.util.spec_from_file_location("cuda_utils", cache_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
@@ -66,7 +68,7 @@ class CudaUtils(object):
|
||||
class CudaDriver(DriverBase):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(CudaDriver, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
@@ -74,14 +76,16 @@ class CudaDriver(DriverBase):
|
||||
self.utils = CudaUtils()
|
||||
self.backend = self.CUDA
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# HIP
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class HIPUtils(object):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(HIPUtils, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
@@ -101,6 +105,7 @@ class HIPUtils(object):
|
||||
with open(so, "rb") as f:
|
||||
cache_path = cache.put(f.read(), fname, binary=True)
|
||||
import importlib.util
|
||||
|
||||
spec = importlib.util.spec_from_file_location("hip_utils", cache_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
@@ -111,7 +116,7 @@ class HIPUtils(object):
|
||||
class HIPDriver(DriverBase):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(HIPDriver, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
@@ -123,7 +128,7 @@ class HIPDriver(DriverBase):
|
||||
class UnsupportedDriver(DriverBase):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(UnsupportedDriver, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
@@ -131,12 +136,14 @@ class UnsupportedDriver(DriverBase):
|
||||
self.utils = None
|
||||
self.backend = None
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Driver
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class LazyProxy:
|
||||
|
||||
def __init__(self, init_fn):
|
||||
self._init_fn = init_fn
|
||||
self._obj = None
|
||||
@@ -150,7 +157,7 @@ class LazyProxy:
|
||||
return getattr(self._obj, name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name in ['_init_fn', '_obj']:
|
||||
if name in ["_init_fn", "_obj"]:
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
self._initialize_obj()
|
||||
@@ -172,6 +179,7 @@ class LazyProxy:
|
||||
|
||||
def initialize_driver():
|
||||
import torch
|
||||
|
||||
if torch.version.hip is not None:
|
||||
return HIPDriver()
|
||||
elif torch.cuda.is_available():
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
|
||||
class OutOfResources(Exception):
|
||||
|
||||
def __init__(self, required, limit, name):
|
||||
self.message = f'out of resource: {name}, '\
|
||||
f'Required: {required}, '\
|
||||
f'Hardware limit: {limit}'
|
||||
self.message += '. Reducing block sizes or `num_stages` may help.'
|
||||
self.message = f"out of resource: {name}, " f"Required: {required}, " f"Hardware limit: {limit}"
|
||||
self.message += ". Reducing block sizes or `num_stages` may help."
|
||||
self.required = required
|
||||
self.limit = limit
|
||||
self.name = name
|
||||
|
||||
@@ -74,11 +74,15 @@ class BlockPointerHandle:
|
||||
|
||||
|
||||
def wrap_ret(compute_ret_ty):

def wrapper(fn):

def wrapped(*args, **kwargs):
ret = fn(*args, **kwargs)
return TensorHandle(ret.data, compute_ret_ty(*args, **kwargs))

return wrapped

return wrapper
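wrap_ret is a decorator factory: the wrapped function still produces the raw result, but its return value is re-wrapped in a TensorHandle whose dtype is computed from the same call arguments. A self-contained illustration with a stand-in TensorHandle (not the interpreter's real class):

import numpy as np

class TensorHandle:
    def __init__(self, data, dtype):
        self.data = data
        self.dtype = dtype

def wrap_ret(compute_ret_ty):
    def wrapper(fn):
        def wrapped(*args, **kwargs):
            ret = fn(*args, **kwargs)
            return TensorHandle(ret.data, compute_ret_ty(*args, **kwargs))
        return wrapped
    return wrapper

@wrap_ret(lambda a, b: a.dtype)   # the result keeps the left operand's dtype
def add(a, b):
    return TensorHandle(a.data + b.data, None)

out = add(TensorHandle(np.ones(4), "fp32"), TensorHandle(np.ones(4), "fp32"))
print(out.dtype)  # fp32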
|
||||
|
||||
|
||||
@@ -249,11 +253,13 @@ class Builder:
|
||||
# ternary functions
|
||||
def ternary_op(self, lhs, rhs, other, op):
|
||||
return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype)
|
||||
|
||||
create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where)
|
||||
|
||||
# unary functions
|
||||
def unary_op(self, arg, op):
|
||||
return TensorHandle(op(arg.data), arg.dtype)
|
||||
|
||||
create_exp = lambda self, arg: self.unary_op(arg, np.exp)
|
||||
create_cos = lambda self, arg: self.unary_op(arg, np.cos)
|
||||
create_sin = lambda self, arg: self.unary_op(arg, np.sin)
|
||||
@@ -279,7 +285,8 @@ class Builder:
|
||||
dtype_tt = ptr.dtype.element_ty
|
||||
return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype)
|
||||
|
||||
def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile):
|
||||
def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy,
|
||||
is_volatile):
|
||||
ptrs, masks = ptr.materialize_pointers(boundary_check)
|
||||
assert padding_option is None
|
||||
other = None
|
||||
@@ -297,6 +304,7 @@ class Builder:
|
||||
|
||||
def create_int_to_ptr(self, val, dst_ty):
|
||||
return TensorHandle(val.data.astype(np.uint64), dst_ty)
|
||||
|
||||
# def create_cat(self, lhs, rhs):
|
||||
# pass
|
||||
|
||||
@@ -360,7 +368,10 @@ class Builder:
|
||||
|
||||
|
||||
def patch_attr(obj, name, member, builder):
new_member = lambda *args, member=member, **kwargs: (member(*args, **{k: v for k, v in kwargs.items() if k != '_builder'}, _builder=builder))
new_member = lambda *args, member=member, **kwargs: (member(*args, **
{k: v
for k, v in kwargs.items()
if k != "_builder"}, _builder=builder))
setattr(obj, name, new_member)
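patch_attr rebinds a language member so that every call is routed through the interpreter's builder: a caller-supplied _builder keyword is silently dropped and the patched-in builder is injected instead. A hedged, standalone illustration with stand-in objects (none of these names are from this diff):

class FakeBuilder:
    name = "interpreter"

def member(x, *, _builder=None):
    return f"{x} via {_builder.name}"

builder = FakeBuilder()
new_member = lambda *args, member=member, **kwargs: member(
    *args, **{k: v for k, v in kwargs.items() if k != "_builder"}, _builder=builder)

print(new_member("load", _builder=None))  # "load via interpreter" -- the caller's _builder is ignored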
|
||||
|
||||
|
||||
@@ -384,8 +395,8 @@ def _patch_lang_core(lang, builder):
|
||||
def _new_reduce(input, axis, combine_fn):
|
||||
fn = combine_fn.fn.__name__
|
||||
mapping = {
|
||||
'maximum': np.max,
|
||||
'_sum_combine': np.sum,
|
||||
"maximum": np.max,
|
||||
"_sum_combine": np.sum,
|
||||
}
|
||||
ret = mapping[fn](input.handle.data, axis=axis)
|
||||
ret_type = tl.block_type(input.dtype, ret.shape)
|
||||
@@ -397,15 +408,16 @@ def _patch_lang_core(lang, builder):
|
||||
def _patch_lang_math(lang, builder):
|
||||
math = lang.math
|
||||
mapping = {
|
||||
'abs': 'abs',
|
||||
'acos': 'arccos',
|
||||
'asin': 'arcsin',
|
||||
'exp2': 'exp2',
|
||||
'log2': 'log2',
|
||||
'max': 'maximum',
|
||||
"abs": "abs",
|
||||
"acos": "arccos",
|
||||
"asin": "arcsin",
|
||||
"exp2": "exp2",
|
||||
"log2": "log2",
|
||||
"max": "maximum",
|
||||
}
|
||||
|
||||
def make_numpy(name):
|
||||
|
||||
def impl(*args, **kwargs):
|
||||
ret_type = args[0].type # TODO: incorrect
|
||||
ret_dtype = args[0].dtype # TODO: incorrect
|
||||
@@ -414,15 +426,18 @@ def _patch_lang_math(lang, builder):
|
||||
ret = getattr(np, mapping[name])(*args, **kwargs)
|
||||
ret = tl.core.tensor(TensorHandle(ret, ret_dtype), ret_type)
|
||||
return ret
|
||||
|
||||
return impl
|
||||
|
||||
def make_fallback(name):
|
||||
|
||||
def fallback(*args, **kwargs):
|
||||
raise NotImplementedError(f"""
|
||||
{name} not supported in interpreter mode: no known numpy implementation.
|
||||
If you think that {name} in fact does have a numpy implementation, please add it
|
||||
to the mapping in python/triton/interpreter/new_interpreter.py:_patch_lang_math.
|
||||
""")
|
||||
|
||||
return fallback
|
||||
|
||||
for name, member in inspect.getmembers(math):
|
||||
@@ -438,7 +453,7 @@ def _implicit_cvt(arg):
|
||||
ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
|
||||
handle = TensorHandle(np.array([arg], dtype=np.int32), ty)
|
||||
return tl.tensor(handle, ty)
|
||||
if hasattr(arg, 'data_ptr'):
|
||||
if hasattr(arg, "data_ptr"):
|
||||
ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
|
||||
handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
|
||||
return tl.tensor(handle, ty)
|
||||
@@ -453,28 +468,29 @@ def _unwrap(tensor):
|
||||
|
||||
builder = Builder()
|
||||
|
||||
RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization', 'enable_fp_fusion']
|
||||
RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_warp_specialization", "enable_fp_fusion"]
|
||||
|
||||
|
||||
class GridExecutor:
|
||||
|
||||
def __init__(self, fn, arg_names, grid):
|
||||
from .jit import _normalize_ty # TODO: modularize
|
||||
|
||||
self.fn = fn
|
||||
self.arg_names = arg_names
|
||||
self.grid = grid
|
||||
__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
|
||||
self.constexprs = [name for name in arg_names if __annotations__.get(name) == 'constexpr']
|
||||
self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"]
|
||||
|
||||
def _patch_lang(self, builder):
|
||||
lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
|
||||
assert len(lang) == 1, "triton.language must be visible from within jit'd function"
|
||||
_patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
|
||||
_patch_lang_tensor(getattr(lang[0], "tensor"), builder)
|
||||
_patch_lang_core(lang[0], builder)
|
||||
_patch_lang_math(lang[0], builder)
|
||||
|
||||
def __call__(self, *args_dev, **kwargs):
|
||||
args_hst = [_unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev]
|
||||
args_hst = [_unwrap(arg).cpu() if hasattr(arg, "data_ptr") else arg for arg in args_dev]
|
||||
# removes reserved keywords from kwargs
|
||||
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}
|
||||
# remaps core language functions to interpreted ones
|
||||
@@ -486,7 +502,7 @@ class GridExecutor:
# iterate through grid
grid = self.grid(args) if callable(self.grid) else self.grid
assert len(grid) <= 3
grid = grid + (1,) * (3 - len(grid))
grid = grid + (1, ) * (3 - len(grid))
builder.set_grid_dim(*grid)
for x in range(grid[0]):
for y in range(grid[1]):
@@ -495,7 +511,7 @@ class GridExecutor:
self.fn(**args)
# copy arguments back to propagate side-effects
for arg_dev, arg_hst in zip(args_dev, args_hst):
if hasattr(arg_dev, 'data_ptr'):
if hasattr(arg_dev, "data_ptr"):
_unwrap(arg_dev).copy_(arg_hst.to(arg_dev.device))
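The grid handling above accepts a 1-, 2- or 3-element launch grid and pads it with ones so the interpreter can always iterate over (x, y, z) program ids. A small illustration of that normalisation:

def normalize_grid(grid):
    assert len(grid) <= 3
    return tuple(grid) + (1, ) * (3 - len(grid))

print(normalize_grid((128, )))  # (128, 1, 1)
print(normalize_grid((4, 4)))   # (4, 4, 1)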
|
||||
|
||||
|
||||
@@ -504,17 +520,18 @@ class InterpretedFunction:
|
||||
def _patch_lang(self, builder):
|
||||
lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
|
||||
assert len(lang) == 1, "triton.language must be visible from within jit'd function"
|
||||
_patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
|
||||
_patch_lang_tensor(getattr(lang[0], "tensor"), builder)
|
||||
_patch_lang_core(lang[0], builder)
|
||||
|
||||
def __init__(self, fn) -> None:
|
||||
self.fn = fn
|
||||
|
||||
def run(*args, **kwargs):
|
||||
grid = kwargs['grid']
|
||||
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ['grid']}
|
||||
grid = kwargs["grid"]
|
||||
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ["grid"]}
|
||||
|
||||
return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs)
|
||||
|
||||
self.run = run
|
||||
signature = inspect.signature(fn)
|
||||
self.arg_names = [v.name for v in signature.parameters.values()]
|
||||
|
||||
@@ -5,48 +5,48 @@ import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import subprocess
|
||||
import textwrap
|
||||
from collections import defaultdict, namedtuple
|
||||
from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast,
|
||||
overload)
|
||||
from functools import cached_property
|
||||
from typing import Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast, overload
|
||||
|
||||
from .._C.libtriton.triton import TMAInfos
|
||||
from ..common.backend import get_backend, path_to_ptxas
|
||||
from ..language.core import dtype
|
||||
from ..common.backend import get_backend, get_cuda_version_key
|
||||
from .interpreter import InterpretedFunction
|
||||
|
||||
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
TRITON_VERSION = "2.1.0"
|
||||
|
||||
|
||||
def get_cuda_stream(idx=None):
|
||||
if idx is None:
|
||||
idx = get_current_device()
|
||||
try:
|
||||
from torch._C import _cuda_getCurrentRawStream
|
||||
|
||||
return _cuda_getCurrentRawStream(idx)
|
||||
except ImportError:
|
||||
import torch
|
||||
|
||||
return torch.cuda.current_stream(idx).cuda_stream
|
||||
|
||||
|
||||
def get_current_device():
|
||||
import torch
|
||||
|
||||
return torch.cuda.current_device()
|
||||
|
||||
|
||||
def set_current_device(idx):
|
||||
import torch
|
||||
|
||||
torch.cuda.set_device(idx)
|
||||
|
||||
|
||||
def get_device_capability(idx):
|
||||
import torch
|
||||
|
||||
return torch.cuda.get_device_capability(idx)
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Dependencies Finder
|
||||
@@ -72,7 +72,8 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
lhs = self.visit(node.value)
|
||||
while isinstance(lhs, ast.Attribute):
|
||||
lhs = self.visit(lhs.value)
|
||||
if lhs is None or (getattr(lhs, "__name__", "") == "triton" or getattr(lhs, "__name__", "").endswith(".triton")):
|
||||
if lhs is None or (getattr(lhs, "__name__", "") == "triton"
|
||||
or getattr(lhs, "__name__", "").endswith(".triton")):
|
||||
return None
|
||||
return getattr(lhs, node.attr)
|
||||
|
||||
@@ -82,55 +83,26 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
return
|
||||
if inspect.isbuiltin(func):
|
||||
return
|
||||
if func.__module__ and (func.__module__.startswith('triton.') or '.triton.' in func.__module__):
|
||||
if func.__module__ and (func.__module__.startswith("triton.") or ".triton." in func.__module__):
|
||||
return
|
||||
assert isinstance(func, JITFunction), f"Function \"{func.__name__}\" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this"
|
||||
assert isinstance(
|
||||
func, JITFunction
|
||||
), f'Function "{func.__name__}" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this'
|
||||
if func.hash is None:
|
||||
tree = ast.parse(func.src)
|
||||
finder = DependenciesFinder(func.__globals__, func.src)
|
||||
finder.visit(tree)
|
||||
func.hash = finder.ret
|
||||
noinline = str(getattr(func, 'noinline', False))
|
||||
noinline = str(getattr(func, "noinline", False))
|
||||
self.ret = (self.ret + func.hash + noinline).encode("utf-8")
|
||||
self.ret = hashlib.sha1(self.ret).hexdigest()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# JITFunction
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def version_key():
|
||||
import pkgutil
|
||||
contents = []
|
||||
# frontend
|
||||
with open(__file__, "rb") as f:
|
||||
contents += [hashlib.sha1(f.read()).hexdigest()]
|
||||
# compiler
|
||||
compiler_path = os.path.join(TRITON_PATH, 'compiler')
|
||||
for lib in pkgutil.iter_modules([compiler_path]):
|
||||
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
||||
contents += [hashlib.sha1(f.read()).hexdigest()]
|
||||
# backend
|
||||
libtriton_hash = hashlib.sha1()
|
||||
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
|
||||
while True:
|
||||
chunk = f.read(1024 ** 2)
|
||||
if not chunk:
|
||||
break
|
||||
libtriton_hash.update(chunk)
|
||||
contents.append(libtriton_hash.hexdigest())
|
||||
# language
|
||||
language_path = os.path.join(TRITON_PATH, 'language')
|
||||
for lib in pkgutil.iter_modules([language_path]):
|
||||
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
||||
contents += [hashlib.sha1(f.read()).hexdigest()]
|
||||
# ptxas version
|
||||
ptxas = path_to_ptxas()[0]
|
||||
ptxas_version = hashlib.sha1(subprocess.check_output([ptxas, "--version"])).hexdigest()
|
||||
return '-'.join(TRITON_VERSION) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||
|
||||
|
||||
def _normalize_ty(ty) -> str:
|
||||
if isinstance(ty, type):
|
||||
return ty.__name__
|
||||
@@ -139,6 +111,85 @@ def _normalize_ty(ty) -> str:
|
||||
return repr(ty)
|
||||
|
||||
|
||||
class KernelParam:
|
||||
"""Represents a parameter to a @jit'ed function.
|
||||
|
||||
A parameter is just the name plus metadata; a parameter plus a value is a
|
||||
KernelArg.
|
||||
"""
|
||||
|
||||
def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool):
|
||||
self.num = num
|
||||
self._param = param
|
||||
self.do_not_specialize = do_not_specialize
|
||||
|
||||
@cached_property
|
||||
def name(self):
|
||||
return self._param.name
|
||||
|
||||
@cached_property
|
||||
def annotation(self):
|
||||
if not self._param.annotation or self._param.annotation == inspect.Parameter.empty:
|
||||
return ""
|
||||
return _normalize_ty(self._param.annotation)
|
||||
|
||||
@cached_property
|
||||
def is_constexpr(self):
|
||||
return "constexpr" in self.annotation
|
||||
|
||||
@property
|
||||
def default(self):
|
||||
return self._param.default
|
||||
|
||||
@property
|
||||
def has_default(self):
|
||||
return self._param.default != inspect.Parameter.empty
|
||||
|
||||
|
||||
class KernelArg:
|
||||
"""Represents an argument to a @jit'ed function.
|
||||
|
||||
An argument is a parameter plus a value.
|
||||
"""
|
||||
|
||||
def __init__(self, value, param):
|
||||
self.value = value
|
||||
self.param = param
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.param.name
|
||||
|
||||
def signature_key(self):
|
||||
annotation = self.param.annotation
|
||||
if "Tensor" in annotation:
|
||||
return self.value.dtype
|
||||
elif annotation == "bool":
|
||||
return "i1"
|
||||
elif annotation == "float":
|
||||
return "fp32"
|
||||
else:
|
||||
return JITFunction._key_of(self.value)
|
||||
|
||||
def specialization_key(self):
|
||||
assert not self.param.do_not_specialize
|
||||
|
||||
try:
return (self.value.data_ptr() % JITFunction.divisibility == 0, )
except AttributeError:
pass

if isinstance(self.value, int):
# bool is a subclass of int, so we don't check explicitly above.
return (
self.value % JITFunction.divisibility == 0,
self.value % JITFunction.divisibility_8 == 0,
self.value == 1,
)

return (False, )
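specialization_key condenses each argument into the facts the compiler actually specializes on: pointer alignment for tensors, and divisibility plus the equals-one case for plain integers. A hedged, standalone rendering of that logic; the constants are illustrative and may differ from the real JITFunction attributes:

DIVISIBILITY = 16
DIVISIBILITY_8 = 8   # assumption: stand-in for JITFunction.divisibility_8

def specialization_key(value):
    if hasattr(value, "data_ptr"):
        return (value.data_ptr() % DIVISIBILITY == 0, )
    if isinstance(value, int) and not isinstance(value, bool):
        return (value % DIVISIBILITY == 0, value % DIVISIBILITY_8 == 0, value == 1)
    return (False, )

print(specialization_key(32))  # (True, True, False)
print(specialization_key(1))   # (False, False, True)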
|
||||
|
||||
|
||||
class KernelInterface(Generic[T]):
|
||||
run: T
|
||||
|
||||
@@ -152,7 +203,6 @@ class KernelInterface(Generic[T]):
|
||||
|
||||
|
||||
class JITFunction(KernelInterface[T]):
|
||||
|
||||
# Hook for inspecting compiled functions and modules
|
||||
cache_hook = None
|
||||
divisibility = 16
|
||||
@@ -169,44 +219,44 @@ class JITFunction(KernelInterface[T]):
|
||||
elif isinstance(arg, bool):
|
||||
return "i1"
|
||||
elif isinstance(arg, int):
|
||||
if -2**31 <= arg and arg <= 2**31 - 1:
|
||||
if -(2**31) <= arg and arg <= 2**31 - 1:
|
||||
return "i32"
|
||||
elif 2**63 <= arg and arg <= 2**64 - 1:
|
||||
return "u64"
|
||||
else:
|
||||
return "i64"
|
||||
elif isinstance(arg, float):
|
||||
return 'fp32'
|
||||
return "fp32"
|
||||
elif arg is None:
|
||||
return None
|
||||
else:
|
||||
raise TypeError(f'Unsupported type {type(arg)} for {arg}')
|
||||
raise TypeError(f"Unsupported type {type(arg)} for {arg}")
|
||||
|
||||
@staticmethod
|
||||
def _device_of(arg):
|
||||
if hasattr(arg, "device"):
|
||||
if hasattr(arg.device, 'type'):
|
||||
return arg.device.type
|
||||
|
||||
return ''
|
||||
try:
|
||||
return arg.device.type
|
||||
except AttributeError:
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _pinned_memory_of(arg):
|
||||
if hasattr(arg, "is_pinned"):
|
||||
if isinstance(arg.is_pinned, Callable):
|
||||
return arg.is_pinned()
|
||||
|
||||
return False
|
||||
try:
|
||||
return arg.is_pinned()
|
||||
except (AttributeError, TypeError):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _spec_of(arg):
|
||||
if hasattr(arg, "data_ptr"):
|
||||
return (arg.data_ptr() % JITFunction.divisibility == 0)
|
||||
return arg.data_ptr() % JITFunction.divisibility == 0
|
||||
elif isinstance(arg, int):
|
||||
return (arg % 16 == 0, arg == 1)
|
||||
return (arg is None, )
|
||||
|
||||
# TODO(jlebar): Fold this into the KernelArg class.
|
||||
def _get_config(self, *args):
|
||||
|
||||
def is_divisible_by_16(x):
|
||||
if hasattr(x, "data_ptr"):
|
||||
return x.data_ptr() % JITFunction.divisibility == 0
|
||||
@@ -222,28 +272,38 @@ class JITFunction(KernelInterface[T]):
|
||||
if x is None:
|
||||
return True
|
||||
return False
|
||||
divisible_by_16 = {i for i, arg in enumerate(
|
||||
args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
|
||||
divisible_by_8 = {i for i, arg in enumerate(
|
||||
args) if is_divisible_by_8(arg) and i not in self.do_not_specialize}
|
||||
|
||||
divisible_by_16 = {
|
||||
param.num
|
||||
for param, arg in zip(self.params, args)
|
||||
if is_divisible_by_16(arg) and not param.do_not_specialize
|
||||
}
|
||||
divisible_by_8 = {
|
||||
param.num
|
||||
for param, arg in zip(self.params, args)
|
||||
if is_divisible_by_8(arg) and not param.do_not_specialize
|
||||
}
|
||||
equal_to_1 = {
|
||||
i for i, arg in enumerate(args) if isinstance(
|
||||
arg, int) and not isinstance(
|
||||
arg, bool) and arg == 1 and i not in self.do_not_specialize}
|
||||
param.num
|
||||
for param, arg in zip(self.params, args)
|
||||
if isinstance(arg, int) and not isinstance(arg, bool) and arg == 1 and not param.do_not_specialize
|
||||
}
|
||||
# folded equal_to_1 and None
|
||||
# TODO: method to collect all folded args
|
||||
none_args = {i for i, arg in enumerate(args) if arg is None and i not in self.do_not_specialize}
|
||||
none_args = {param.num for param, arg in zip(self.params, args) if arg is None and not param.do_not_specialize}
|
||||
ids_of_folded_args = equal_to_1 | none_args
|
||||
return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])(
|
||||
tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), tuple(divisible_by_8))
|
||||
return namedtuple("instance_descriptor",
|
||||
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])( #
|
||||
tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args),
|
||||
tuple(divisible_by_8))
|
||||
# return _triton.code_gen.instance_descriptor(divisible_by_16,
|
||||
# equal_to_1)
|
||||
|
||||
@staticmethod
|
||||
def _type_of(key):
|
||||
# None are nullptr -- implicitly converted to *i8
|
||||
# `None` is nullptr. Implicitly convert to *i8.
|
||||
if key is None:
|
||||
return '*i8'
|
||||
return "*i8"
|
||||
dtype_str = str(key).split(".")[-1]
|
||||
tys = {
|
||||
"bool": "i1",
|
||||
@@ -281,21 +341,46 @@ class JITFunction(KernelInterface[T]):
|
||||
constants = dict(zip(self.constexprs, constexpr_key))
|
||||
return constants
|
||||
|
||||
<<<<<<< HEAD
|
||||
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization,enable_fp_fusion, extern_libs, configs):
|
||||
=======
|
||||
def _call_hook(
|
||||
self,
|
||||
key,
|
||||
signature,
|
||||
device,
|
||||
constants,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
num_stages,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
extern_libs,
|
||||
configs,
|
||||
):
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
if JITFunction.cache_hook is None:
|
||||
return False
|
||||
|
||||
name = self.fn.__name__
|
||||
module = self.fn.__module__
|
||||
<<<<<<< HEAD
|
||||
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs}), enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
|
||||
=======
|
||||
arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])])
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}, enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
key = str(key)
|
||||
|
||||
class LegacyCompiler:
|
||||
|
||||
def __init__(self, module, name):
|
||||
self.module = module
|
||||
self.name = name
|
||||
pass
|
||||
|
||||
<<<<<<< HEAD
|
||||
kwargs = dict(signature=signature, device=device, constants=constants,
|
||||
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs,
|
||||
configs=configs)
|
||||
@@ -326,18 +411,43 @@ class JITFunction(KernelInterface[T]):
|
||||
return 'fp32'
|
||||
else:
|
||||
return self._key_of(arg)
|
||||
=======
|
||||
kwargs = dict(
|
||||
signature=signature,
|
||||
device=device,
|
||||
constants=constants,
|
||||
num_warps=num_warps,
|
||||
num_ctas=num_ctas,
|
||||
num_stages=num_stages,
|
||||
enable_warp_specialization=enable_warp_specialization,
|
||||
enable_fp_fusion=enable_fp_fusion,
|
||||
extern_libs=extern_libs,
|
||||
configs=configs,
|
||||
)
|
||||
|
||||
return JITFunction.cache_hook(
|
||||
key=key,
|
||||
repr=repr,
|
||||
fn=LegacyCompiler(module, name),
|
||||
compile={"key": key, **kwargs},
|
||||
is_manual_warmup=False,
|
||||
already_compiled=False,
|
||||
)
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
|
||||
def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: List[bool]) -> str:
|
||||
device_types = [device_type for device_type in device_types if device_type != '']
|
||||
device_types = [device_type for device_type in device_types if device_type != ""]
|
||||
# Return cuda if one of the input tensors is cuda
|
||||
if 'cuda' in device_types:
|
||||
if "cuda" in device_types:
|
||||
import torch
|
||||
return 'hip' if torch.version.hip else 'cuda'
|
||||
|
||||
is_cpu = all(device_type == 'cpu' for device_type in device_types)
|
||||
return "hip" if torch.version.hip else "cuda"
|
||||
|
||||
is_cpu = all(device_type == "cpu" for device_type in device_types)
|
||||
is_pinned_memory = any(pinned_memory_flag for pinned_memory_flag in pinned_memory_flags)
|
||||
# Return cuda if all the input tensors are cpu while the memory is pinned
|
||||
if is_cpu and is_pinned_memory:
|
||||
<<<<<<< HEAD
|
||||
return 'cuda'
|
||||
|
||||
return device_types[0] if len(device_types) > 0 else 'cuda'
|
||||
@@ -452,16 +562,193 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
scope = {"launcher_body": launcher_body}
|
||||
exec(src, scope)
|
||||
return scope[self.fn.__name__]
|
||||
=======
|
||||
return "cuda"
|
||||
|
||||
return device_types[0] if len(device_types) > 0 else "cuda"
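# Hedged, standalone sketch of the device-type resolution above (the real method
# consults torch.version.hip; here that check is replaced by an `is_hip` flag):
def conclude_device_type_sketch(device_types, pinned_memory_flags, is_hip=False):
    device_types = [t for t in device_types if t != ""]
    if "cuda" in device_types:
        return "hip" if is_hip else "cuda"
    if all(t == "cpu" for t in device_types) and any(pinned_memory_flags):
        return "cuda"  # pinned host memory is still driven through the CUDA/HIP path
    return device_types[0] if device_types else "cuda"

print(conclude_device_type_sketch(["cpu", "cuda"], [False, False], is_hip=True))  # -> hip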
def run(self, *args, **kwargs):
|
||||
from ..compiler import CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps
|
||||
|
||||
# Get a compiler-flags arg like `num_warps` and remove it from kwargs.
|
||||
def get_special_arg(name: str, default=None):
|
||||
if name not in kwargs:
|
||||
return default
|
||||
ret = kwargs[name]
|
||||
del kwargs[name]
|
||||
return ret
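# (Illustration: get_special_arg("num_warps") behaves like kwargs.pop("num_warps", None),
#  so compiler flags are peeled off before the remaining kwargs are bound to the signature.)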
grid = get_special_arg("grid")
|
||||
num_warps = get_special_arg("num_warps")
|
||||
num_ctas = get_special_arg("num_ctas", 1)
|
||||
num_stages = get_special_arg("num_stages")
|
||||
enable_warp_specialization = get_special_arg("enable_warp_specialization", False)
|
||||
enable_fp_fusion = get_special_arg("enable_fp_fusion", True)
|
||||
extern_libs = get_special_arg("extern_libs")
|
||||
stream = get_special_arg("stream")
|
||||
warmup = get_special_arg("warmup", False)
|
||||
device = get_special_arg("device")
|
||||
device_type = get_special_arg("device_type")
|
||||
|
||||
# Bind the remaining arguments to `fn`.
|
||||
bound_args = self.signature.bind(*args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
|
||||
assert len(bound_args.arguments) == len(self.params)
|
||||
args = [KernelArg(arg_value, param) for (_, arg_value), param in zip(bound_args.arguments.items(), self.params)]
|
||||
|
||||
non_constexpr_arg_values = [arg.value for arg in args if not arg.param.is_constexpr]
|
||||
|
||||
sig_key = tuple(arg.signature_key() for arg in args if not arg.param.is_constexpr)
|
||||
spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize)
|
||||
constexpr_key = tuple(arg.value for arg in args if arg.param.is_constexpr)
|
||||
|
||||
assert num_ctas > 0
|
||||
assert grid is not None
|
||||
if callable(grid):
|
||||
# Arguments are passed as a dict to `grid`, by contract.
|
||||
# TODO(jlebar): In the new launch API, pass the compiler flags as a
|
||||
# second parameter to `grid`.
|
||||
grid = grid(dict(bound_args.arguments))
|
||||
grid_size = len(grid)
|
||||
grid_0 = grid[0]
|
||||
grid_1 = grid[1] if grid_size > 1 else 1
|
||||
grid_2 = grid[2] if grid_size > 2 else 1
|
||||
if device_type is None:
|
||||
device_types = [self._device_of(arg) for arg in non_constexpr_arg_values]
|
||||
device_types = [_device_type for _device_type in device_types if _device_type != ""]
|
||||
device_type = self._conclude_device_type(device_types,
|
||||
[self._pinned_memory_of(arg) for arg in non_constexpr_arg_values])
|
||||
|
||||
device_backend = None
|
||||
if device_type not in ["cuda"]:
|
||||
device_backend = get_backend(device_type)
|
||||
if device_backend is None:
|
||||
raise ValueError("Cannot find backend for " + device_type)
|
||||
|
||||
if device is None:
|
||||
if device_type in ["cuda"]:
|
||||
device = get_current_device()
|
||||
set_current_device(device)
|
||||
else:
|
||||
device = device_backend.get_current_device()
|
||||
device_backend.set_current_device(device)
|
||||
if stream is None and not warmup:
|
||||
if device_type in ["cuda"]:
|
||||
stream = get_cuda_stream(device)
|
||||
else:
|
||||
stream = device_backend.get_stream()
|
||||
|
||||
if num_warps is None:
|
||||
num_warps = get_arch_default_num_warps(device_type)
|
||||
if num_stages is None:
|
||||
num_stages = get_arch_default_num_stages(device_type)
|
||||
|
||||
if device_type in ["cuda"]:
|
||||
version_key = get_cuda_version_key()
|
||||
else:
|
||||
version_key = device_backend.get_version_key()
|
||||
key = (
|
||||
version_key,
|
||||
sig_key,
|
||||
constexpr_key,
|
||||
spec_key,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
num_stages,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
self.debug,
|
||||
)
|
||||
if extern_libs is not None:
|
||||
key = (key, tuple(extern_libs.items()))
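# (For reference, the assembled key then has the nested shape
#  ((version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages,
#    enable_warp_specialization, enable_fp_fusion, debug), tuple(extern_libs.items()))
#  whenever extern_libs is supplied.)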
# Kernel is not cached; we have to compile.
|
||||
if key not in self.cache[device]:
|
||||
configs = (self._get_config(*[arg.value for arg in args]), )
|
||||
constants = {
|
||||
arg.param.num: arg.value
|
||||
for arg in args
|
||||
if arg.param.is_constexpr or arg.param.num in configs[0].equal_to_1 or arg.value is None
|
||||
}
|
||||
for i, arg in constants.items():
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {i} is not supported")
|
||||
|
||||
# Build kernel signature -- doesn't include constexpr arguments.
|
||||
signature = {
|
||||
arg.param.num: self._type_of(self._key_of(arg.value))
|
||||
for arg in args
|
||||
if not arg.param.is_constexpr
|
||||
}
|
||||
|
||||
if self._call_hook(
|
||||
key,
|
||||
signature,
|
||||
device,
|
||||
constants,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
num_stages,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
extern_libs,
|
||||
configs,
|
||||
):
|
||||
return None
|
||||
|
||||
self.cache[device][key] = compile(
|
||||
self,
|
||||
signature=signature,
|
||||
device=device,
|
||||
constants=constants,
|
||||
num_warps=num_warps,
|
||||
num_ctas=num_ctas,
|
||||
num_stages=num_stages,
|
||||
enable_warp_specialization=enable_warp_specialization,
|
||||
enable_fp_fusion=enable_fp_fusion,
|
||||
extern_libs=extern_libs,
|
||||
configs=configs,
|
||||
debug=self.debug,
|
||||
device_type=device_type,
|
||||
)
|
||||
|
||||
bin = self.cache[device][key]
|
||||
if not warmup:
|
||||
bin.c_wrapper(
|
||||
grid_0,
|
||||
grid_1,
|
||||
grid_2,
|
||||
bin.num_warps,
|
||||
bin.num_ctas,
|
||||
bin.clusterDims[0],
|
||||
bin.clusterDims[1],
|
||||
bin.clusterDims[2],
|
||||
bin.shared,
|
||||
stream,
|
||||
bin.cu_function,
|
||||
CompiledKernel.launch_enter_hook,
|
||||
CompiledKernel.launch_exit_hook,
|
||||
bin,
|
||||
*bin.assemble_tensormap_to_arg(non_constexpr_arg_values),
|
||||
)
|
||||
return bin
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
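# Usage sketch of the launch path above, as a separate script (assumes a CUDA/HIP
# device and a trivial jitted kernel; the callable grid receives the bound
# arguments as a dict, matching `grid = grid(dict(bound_args.arguments))` above):
import torch
import triton
import triton.language as tl

@triton.jit
def _copy(X, Z, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    tl.store(Z + offs, tl.load(X + offs, mask=offs < N), mask=offs < N)

x = torch.randn(1000, device="cuda")
z = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK"]), )
_copy[grid](x, z, x.numel(), BLOCK=256, num_warps=4)  # first call compiles, later calls hit self.cache
assert torch.equal(x, z)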
def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None):
|
||||
do_not_specialize = do_not_specialize if do_not_specialize else []
|
||||
|
||||
self.fn = fn
|
||||
self.module = fn.__module__
|
||||
self.version = version
|
||||
# function signature information
|
||||
signature = inspect.signature(fn)
|
||||
self.arg_names = [v.name for v in signature.parameters.values()]
|
||||
self.arg_defaults = [v.default for v in signature.parameters.values()]
|
||||
self.has_defaults = any(v != inspect._empty for v in self.arg_defaults)
|
||||
self.signature = inspect.signature(fn)
|
||||
self.do_not_specialize = do_not_specialize
|
||||
|
||||
self.params = []
|
||||
for i, param in enumerate(self.signature.parameters.values()):
|
||||
dns = do_not_specialize and (i in do_not_specialize or param.name in do_not_specialize)
|
||||
self.params.append(KernelParam(i, param, dns))
|
||||
|
||||
# function source code (without decorators)
|
||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||
self.src = self.src[self.src.find("def"):]
|
||||
@@ -470,22 +757,18 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
self.hash = None
|
||||
# JITFunction can be instantiated as kernel
|
||||
# when called with a grid using __getitem__
|
||||
self.kernel_decorators = []
|
||||
self.kernel = None
|
||||
self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug
|
||||
self.noinline = noinline
|
||||
# annotations
|
||||
self.__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
|
||||
# index of constexprs
|
||||
self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
|
||||
# specialization hints
|
||||
regular_args = [arg for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
|
||||
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
|
||||
self.do_not_specialize = {regular_args.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
|
||||
|
||||
# tma info
|
||||
self.tensormaps_info = TMAInfos()
|
||||
# launcher
|
||||
self.run = self._make_launcher()
|
||||
|
||||
# TODO(jlebar): Remove uses of these fields outside this file, then
|
||||
# remove the fields here.
|
||||
self.arg_names = [p.name for p in self.params]
|
||||
self.constexprs = [p.num for p in self.params if p.is_constexpr]
|
||||
|
||||
# re-use docs of wrapped function
|
||||
self.__doc__ = fn.__doc__
|
||||
self.__name__ = fn.__name__
|
||||
@@ -498,7 +781,7 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
if self.hash is None:
|
||||
dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
|
||||
dependencies_finder.visit(self.parse())
|
||||
self.hash = dependencies_finder.ret + version_key()
|
||||
self.hash = dependencies_finder.ret
|
||||
return self.hash
|
||||
|
||||
def warmup(self, *args, **kwargs):
|
||||
@@ -518,14 +801,10 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
# - when kernel decorators change, cached kernel
|
||||
# needs to be cleared
|
||||
if name == 'kernel_decorators':
|
||||
self.kernel = None
|
||||
super(JITFunction, self).__setattr__(name, value)
|
||||
# - when `.src` attribute is set, cache path needs
|
||||
# to be reinitialized
|
||||
if name == 'src':
|
||||
if name == "src":
|
||||
self.hash = None
|
||||
|
||||
def __repr__(self):
|
||||
@@ -591,12 +870,14 @@ def jit(
|
||||
debug=debug,
|
||||
noinline=noinline,
|
||||
)
|
||||
|
||||
if fn is not None:
|
||||
return decorator(fn)
|
||||
|
||||
else:
|
||||
return decorator
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utilities for mocking tensors
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -607,10 +888,10 @@ class MockTensor:
|
||||
Can be used in place of real tensors when calling:
|
||||
kernel.warmup(MockTensor(torch.float32), ...)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def wrap_dtype(arg):
|
||||
if arg.__class__.__name__ == "dtype" and\
|
||||
arg.__module__ == "torch":
|
||||
if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch":
|
||||
return MockTensor(arg)
|
||||
return arg
|
||||
|
||||
@@ -623,6 +904,7 @@ class MockTensor:
|
||||
|
||||
|
||||
class TensorWrapper:
|
||||
|
||||
def __init__(self, base, dtype):
|
||||
self.dtype = dtype
|
||||
self.base = base
|
||||
@@ -637,7 +919,7 @@ class TensorWrapper:
|
||||
return self.base.stride(i)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'TensorWrapper[{self.dtype}]({self.base})'
|
||||
return f"TensorWrapper[{self.dtype}]({self.base})"
|
||||
|
||||
def element_size(self):
|
||||
return self.base.element_size()
|
||||
@@ -655,4 +937,4 @@ def reinterpret(tensor, dtype):
|
||||
# A new wrapper is needed around an unwrapped tensor.
|
||||
return TensorWrapper(tensor, dtype)
|
||||
else:
|
||||
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
|
||||
raise TypeError(f"Cannot reinterpret a {type(tensor)}.")
|
||||
|
||||
@@ -78,10 +78,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None):
|
||||
return torch.mean(torch.tensor(ret)).item()
|
||||
|
||||
|
||||
def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
|
||||
quantiles=None,
|
||||
fast_flush=True,
|
||||
return_mode="mean"):
|
||||
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"):
|
||||
assert return_mode in ["min", "max", "mean", "median"]
|
||||
import torch
|
||||
"""
|
||||
@@ -261,11 +258,12 @@ class Benchmark:
|
||||
|
||||
|
||||
class Mark:
|
||||
|
||||
def __init__(self, fn, benchmarks):
|
||||
self.fn = fn
|
||||
self.benchmarks = benchmarks
|
||||
|
||||
def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, **kwrags):
|
||||
def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False, **kwrags):
|
||||
import os
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -321,24 +319,36 @@ class Mark:
|
||||
if save_path:
|
||||
plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
|
||||
df = df[x_names + bench.line_names]
|
||||
if diff_col and df.shape[1] == 2:
|
||||
col0, col1 = df.columns.tolist()
|
||||
df['Diff'] = df[col1] - df[col0]
|
||||
|
||||
if print_data:
|
||||
print(bench.plot_name + ':')
|
||||
print(df)
|
||||
if save_path:
|
||||
df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)
|
||||
return df
|
||||
|
||||
def run(self, show_plots=False, print_data=False, save_path='', **kwargs):
|
||||
def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs):
|
||||
has_single_bench = isinstance(self.benchmarks, Benchmark)
|
||||
benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
|
||||
result_dfs = []
|
||||
if save_path:
|
||||
html = open(os.path.join(save_path, "results.html"), "w")
|
||||
html.write("<html><body>\n")
|
||||
for bench in benchmarks:
|
||||
self._run(bench, save_path, show_plots, print_data, **kwargs)
|
||||
result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
|
||||
if save_path:
|
||||
html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
|
||||
if save_path:
|
||||
html.write("</body></html>\n")
|
||||
if return_df:
|
||||
if has_single_bench:
|
||||
return result_dfs[0]
|
||||
else:
|
||||
return result_dfs
|
||||
return None
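# Hedged usage sketch for the `return_df` option added to run() above (assumes
# torch, pandas and matplotlib are available; values and providers are made up):
import torch
import triton
import triton.testing

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"], x_vals=[2**i for i in range(10, 14)], x_log=True,
        line_arg="provider", line_vals=["torch", "triton"], line_names=["Torch", "Triton"],
        ylabel="GB/s", plot_name="demo", args={}))
def bench(N, provider):
    x = torch.randn(N, device="cuda")
    return triton.testing.do_bench(lambda: x + x)  # same op for both lines; illustration only

dfs = bench.run(print_data=True, return_df=True)  # with a single Benchmark this is one DataFrame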
def perf_report(benchmarks):
|
||||
@@ -393,12 +403,15 @@ def get_max_tensorcore_tflops(dtype, clock_rate, backend=None, device=None):
|
||||
tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
|
||||
return tflops
|
||||
|
||||
|
||||
# create decorator that wraps test function into
|
||||
# a cuda-memcheck system call
|
||||
|
||||
|
||||
def cuda_memcheck(**target_kwargs):
|
||||
|
||||
def decorator(test_fn):
|
||||
|
||||
@functools.wraps(test_fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
import psutil
|
||||
@@ -416,7 +429,9 @@ def cuda_memcheck(**target_kwargs):
|
||||
assert "ERROR SUMMARY: 0 errors" in str(out.stdout)
|
||||
else:
|
||||
test_fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@@ -424,22 +439,18 @@ def cuda_memcheck(**target_kwargs):
|
||||
def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
|
||||
try:
|
||||
subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"])
|
||||
subprocess.check_output(
|
||||
[
|
||||
"nvidia-smi",
|
||||
"-i",
|
||||
"0",
|
||||
f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
|
||||
]
|
||||
)
|
||||
subprocess.check_output(
|
||||
[
|
||||
"nvidia-smi",
|
||||
"-i",
|
||||
"0",
|
||||
f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
|
||||
]
|
||||
)
|
||||
subprocess.check_output([
|
||||
"nvidia-smi",
|
||||
"-i",
|
||||
"0",
|
||||
f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
|
||||
])
|
||||
subprocess.check_output([
|
||||
"nvidia-smi",
|
||||
"-i",
|
||||
"0",
|
||||
f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
|
||||
])
|
||||
cur_sm_clock = nvsmi(["clocks.current.sm"])[0]
|
||||
cur_mem_clock = nvsmi(["clocks.current.memory"])[0]
|
||||
assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz"
|
||||
|
||||
@@ -141,8 +141,7 @@ class ExternLibrary(ABC):
|
||||
f.write(file_str)
|
||||
f.close()
|
||||
if self._format:
|
||||
subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file],
|
||||
stdout=subprocess.PIPE).communicate()
|
||||
subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], stdout=subprocess.PIPE).communicate()
|
||||
subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate()
|
||||
|
||||
|
||||
@@ -208,56 +207,36 @@ class Libdevice(ExternLibrary):
|
||||
|
||||
# Group functions together by renaming.
|
||||
renaming = {
|
||||
'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh',
|
||||
'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': 'add_rn',
|
||||
'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru',
|
||||
'dadd_rz': 'add_rz', 'fadd_rz': 'add_rz', 'asinf': 'asin',
|
||||
'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2',
|
||||
'atanhf': 'atanh', 'brevll': 'brev', 'cbrtf': 'cbrt',
|
||||
'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign',
|
||||
'cosf': 'cos', 'coshf': 'cosh', 'cospif': 'cospi',
|
||||
'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
|
||||
'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn',
|
||||
'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', 'ddiv_ru': 'div_ru',
|
||||
'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf',
|
||||
'erfcf': 'erfc', 'erfcinvf': 'erfcinv', 'erfcxf': 'erfcx',
|
||||
'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10',
|
||||
'exp2f': 'exp2', 'expm1f': 'expm1', 'fabsf': 'abs',
|
||||
'fabs': 'abs', 'fast_fdividef': 'fast_dividef',
|
||||
'fdimf': 'fdim', 'ffsll': 'ffs', 'floorf': 'floor',
|
||||
'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn',
|
||||
'fmaf_ru': 'fma_ru', 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod',
|
||||
'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb',
|
||||
'isinff': 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan',
|
||||
'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
|
||||
'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint',
|
||||
'llroundf': 'llround', 'logf': 'log', 'log10f': 'log10',
|
||||
'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb',
|
||||
'umax': 'max', 'llmax': 'max', 'ullmax': 'max', 'fmaxf': 'max',
|
||||
'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min',
|
||||
'fminf': 'min', 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd',
|
||||
'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', 'dmul_ru': 'mul_ru',
|
||||
'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz',
|
||||
'umul24': 'mul24', 'umulhi': 'mulhi', 'mul64hi': 'mulhi',
|
||||
'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': 'nextafter',
|
||||
'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf',
|
||||
'normcdfinvf': 'normcdfinv', 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow',
|
||||
'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', 'drcp_rd': 'rcp_rd',
|
||||
'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru',
|
||||
'drcp_ru': 'rcp_ru', 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz',
|
||||
'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
|
||||
'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d',
|
||||
'roundf': 'round', 'rsqrtf': 'rsqrt', 'frsqrt_rn': 'rsqrt_rn',
|
||||
'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit',
|
||||
'signbitd': 'signbit', 'sinf': 'sin', 'sinhf': 'sinh',
|
||||
'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd',
|
||||
'dsqrt_rd': 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn',
|
||||
'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', 'fsqrt_rz': 'sqrt_rz',
|
||||
'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd',
|
||||
'fsub_rn': 'sub_rn', 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru',
|
||||
'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
|
||||
'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc',
|
||||
'y0f': 'y0', 'y1f': 'y1', 'ynf': 'yn'
|
||||
'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn':
|
||||
'add_rn', 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', 'dadd_rz': 'add_rz', 'fadd_rz':
|
||||
'add_rz', 'asinf': 'asin', 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', 'atanhf': 'atanh',
|
||||
'brevll': 'brev', 'cbrtf': 'cbrt', 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', 'cosf': 'cos',
|
||||
'coshf': 'cosh', 'cospif': 'cospi', 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
|
||||
'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru',
|
||||
'ddiv_ru': 'div_ru', 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', 'erfcf': 'erfc', 'erfcinvf':
|
||||
'erfcinv', 'erfcxf': 'erfcx', 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', 'exp2f': 'exp2',
|
||||
'expm1f': 'expm1', 'fabsf': 'abs', 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', 'fdimf': 'fdim', 'ffsll':
|
||||
'ffs', 'floorf': 'floor', 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', 'fmaf_ru': 'fma_ru',
|
||||
'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', 'isinff':
|
||||
'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
|
||||
'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', 'llroundf': 'llround', 'logf': 'log', 'log10f':
|
||||
'log10', 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', 'umax': 'max', 'llmax': 'max', 'ullmax':
|
||||
'max', 'fmaxf': 'max', 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', 'fminf': 'min',
|
||||
'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn',
|
||||
'dmul_ru': 'mul_ru', 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', 'umul24': 'mul24',
|
||||
'umulhi': 'mulhi', 'mul64hi': 'mulhi', 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf':
|
||||
'nextafter', 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', 'normcdfinvf': 'normcdfinv',
|
||||
'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd',
|
||||
'drcp_rd': 'rcp_rd', 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', 'drcp_ru': 'rcp_ru',
|
||||
'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
|
||||
'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', 'roundf': 'round', 'rsqrtf': 'rsqrt',
|
||||
'frsqrt_rn': 'rsqrt_rn', 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', 'signbitd': 'signbit',
|
||||
'sinf': 'sin', 'sinhf': 'sinh', 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', 'dsqrt_rd':
|
||||
'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru',
|
||||
'fsqrt_rz': 'sqrt_rz', 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', 'fsub_rn': 'sub_rn',
|
||||
'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
|
||||
'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', 'y0f': 'y0', 'y1f': 'y1', 'ynf':
|
||||
'yn'
|
||||
}
|
||||
|
||||
for symbol in self._symbols.values():
|
||||
@@ -347,8 +326,7 @@ class LLVMDisassembler:
|
||||
self._ll_file = "/tmp/extern_lib.ll"
|
||||
|
||||
def disasm(self, lib_path: str) -> None:
|
||||
subprocess.Popen([self._path, lib_path, "-o", self.ll_file],
|
||||
stdout=subprocess.PIPE).communicate()
|
||||
subprocess.Popen([self._path, lib_path, "-o", self.ll_file], stdout=subprocess.PIPE).communicate()
|
||||
|
||||
@property
|
||||
def ll_file(self) -> str:
|
||||
|
||||
@@ -40,10 +40,13 @@ if __name__ == "__main__":
|
||||
|
||||
# command-line arguments
|
||||
parser = ArgumentParser(description=desc)
|
||||
parser.add_argument("path", help="Path to Python source containing desired kernel in its scope. File will be executed.")
|
||||
parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", required=True)
|
||||
parser.add_argument("path",
|
||||
help="Path to Python source containing desired kernel in its scope. File will be executed.")
|
||||
parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile",
|
||||
required=True)
|
||||
parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel")
|
||||
parser.add_argument("--num-stages", "-ns", type=int, default=3, help="Number of stages (meta-parameter of the kernel)")
|
||||
parser.add_argument("--num-stages", "-ns", type=int, default=3,
|
||||
help="Number of stages (meta-parameter of the kernel)")
|
||||
parser.add_argument("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel")
|
||||
parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename")
|
||||
parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True)
|
||||
@@ -104,7 +107,8 @@ if __name__ == "__main__":
|
||||
config = triton.compiler.instance_descriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1)
|
||||
for i in equal_to_1:
|
||||
constexprs.update({i: 1})
|
||||
ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config], num_warps=args.num_warps, num_stages=args.num_stages)
|
||||
ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config],
|
||||
num_warps=args.num_warps, num_stages=args.num_stages)
|
||||
arg_names = []
|
||||
arg_types = []
|
||||
for i in signature.keys():
|
||||
|
||||
@@ -27,6 +27,7 @@ class KernelLinkerMeta:
|
||||
|
||||
|
||||
class HeaderParser:
|
||||
|
||||
def __init__(self) -> None:
|
||||
import re
|
||||
|
||||
@@ -42,7 +43,6 @@ class HeaderParser:
|
||||
self.kernels = defaultdict(list)
|
||||
|
||||
def extract_linker_meta(self, header: str):
|
||||
|
||||
for ln in header.splitlines():
|
||||
if ln.startswith("//"):
|
||||
m = self.linker_directives.match(ln)
|
||||
@@ -76,7 +76,7 @@ class HeaderParser:
|
||||
m = self.c_sig.findall(c_sig)
|
||||
if len(m):
|
||||
tys, args = [], []
|
||||
for (ty, arg_name) in m:
|
||||
for ty, arg_name in m:
|
||||
tys.append(ty)
|
||||
args.append(arg_name)
|
||||
return tys, args
|
||||
@@ -84,7 +84,7 @@ class HeaderParser:
|
||||
raise LinkerError(f"{c_sig} is not a valid argument signature")
|
||||
|
||||
def _match_suffix(self, suffix: str, c_sig: str):
|
||||
args = c_sig.split(',')
|
||||
args = c_sig.split(",")
|
||||
s2i = {"c": 1, "d": 16}
|
||||
num_specs = 0
|
||||
sizes = []
|
||||
@@ -110,7 +110,7 @@ class HeaderParser:
|
||||
if name in self.kernels:
|
||||
last: KernelLinkerMeta = self.kernels[name][-1]
|
||||
|
||||
for (cur, new_) in zip(last.arg_ctypes, ker.arg_ctypes):
|
||||
for cur, new_ in zip(last.arg_ctypes, ker.arg_ctypes):
|
||||
if cur != new_:
|
||||
raise LinkerError(
|
||||
f"Mismatched signature for kernel {name}: \n\texisting sig is: {','.join(last.arg_ctypes)}\n\tcurrent is: {','.join(ker.arg_ctypes)}"
|
||||
@@ -152,7 +152,7 @@ void unload_{meta.orig_kernel_name}();
|
||||
# generate dispatcher function for kernels with different meta-parameter and constant values
|
||||
def make_default_algo_kernel(meta: KernelLinkerMeta) -> str:
|
||||
src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n"
|
||||
src += f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n"
|
||||
src += (f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n")
|
||||
src += "}\n"
|
||||
return src
|
||||
|
||||
@@ -164,12 +164,22 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -
|
||||
src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n"
|
||||
src += "\n"
|
||||
|
||||
src += f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{"
|
||||
src += (f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{")
|
||||
src += "\n"
|
||||
for meta in sorted(metas, key=lambda m: -m.num_specs):
|
||||
cond_fn = lambda val, hint: f"({val} % {hint} == 0)" if hint == 16 else f"({val} == {hint})" if hint == 1 else None
|
||||
conds = " && ".join([cond_fn(val, hint) for val, hint in zip(meta.arg_names, meta.sizes) if hint is not None])
|
||||
src += f" if ({conds})\n"
|
||||
cond_fn = ( #
|
||||
lambda val, hint: f"({val} % {hint} == 0)" #
|
||||
if hint == 16 #
|
||||
else f"({val} == {hint})" #
|
||||
if hint == 1 #
|
||||
else None)
|
||||
conds = " && ".join([ #
|
||||
cond_fn(val, hint) #
|
||||
for val, hint in zip(meta.arg_names, meta.sizes) #
|
||||
if hint is not None
|
||||
])
|
||||
src += (f" if ({conds})\n" if any(meta.sizes) else "if (1)\n"
|
||||
) # Edge case: no specializations, hence no dispatching required
|
||||
arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1]
|
||||
src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n"
|
||||
src += "\n"
|
||||
@@ -183,7 +193,7 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -
|
||||
src += f"void {mode}_{name}() {{"
|
||||
src += "\n"
|
||||
for meta in sorted(metas, key=lambda m: -m.num_specs):
|
||||
src += f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n"
|
||||
src += (f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n")
|
||||
src += "}\n"
|
||||
return src
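# Standalone illustration of the hint-to-condition mapping used by the
# dispatcher generator above (argument names and hints below are made up):
def cond(val, hint):
    if hint == 16:
        return f"({val} % {hint} == 0)"
    if hint == 1:
        return f"({val} == {hint})"
    return None

names, sizes = ["n_elements", "stride_x"], [16, None]
conds = " && ".join(cond(v, h) for v, h in zip(names, sizes) if h is not None)
print(conds)  # -> (n_elements % 16 == 0)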
@@ -252,7 +262,12 @@ if __name__ == "__main__":
|
||||
help="Paths to header files to link. Must include linker directive annotations (autogenerated by ttc)",
|
||||
)
|
||||
parser.add_argument("--out", "-o", type=Path, help="Out filename")
|
||||
parser.add_argument("--prefix", type=str, default="", help="String to prefix kernel dispatcher names")
|
||||
parser.add_argument(
|
||||
"--prefix",
|
||||
type=str,
|
||||
default="",
|
||||
help="String to prefix kernel dispatcher names",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# metadata
|
||||
|
||||
@@ -25,14 +25,13 @@ import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def add_kernel(
|
||||
x_ptr, # *Pointer* to first input vector.
|
||||
y_ptr, # *Pointer* to second input vector.
|
||||
output_ptr, # *Pointer* to output vector.
|
||||
n_elements, # Size of the vector.
|
||||
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.
|
||||
# NOTE: `constexpr` so it can be used as a shape value.
|
||||
):
|
||||
def add_kernel(x_ptr, # *Pointer* to first input vector.
|
||||
y_ptr, # *Pointer* to second input vector.
|
||||
output_ptr, # *Pointer* to output vector.
|
||||
n_elements, # Size of the vector.
|
||||
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.
|
||||
# NOTE: `constexpr` so it can be used as a shape value.
|
||||
):
|
||||
# There are multiple 'programs' processing different data. We identify which program
|
||||
# we are here:
|
||||
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
|
||||
@@ -66,7 +65,7 @@ def add(x: torch.Tensor, y: torch.Tensor):
|
||||
# The SPMD launch grid denotes the number of kernel instances that run in parallel.
|
||||
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
|
||||
# In this case, we use a 1D grid where the size is the number of blocks:
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
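# (triton.cdiv is a ceiling division, so the grid always covers the tail block:
#  e.g. triton.cdiv(98432, 1024) == 97, i.e. 96 full blocks plus one partial block.)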
# NOTE:
|
||||
# - Each torch.tensor object is implicitly converted into a pointer to its first element.
|
||||
# - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
|
||||
@@ -88,10 +87,8 @@ output_torch = x + y
|
||||
output_triton = add(x, y)
|
||||
print(output_torch)
|
||||
print(output_triton)
|
||||
print(
|
||||
f'The maximum difference between torch and triton is '
|
||||
f'{torch.max(torch.abs(output_torch - output_triton))}'
|
||||
)
|
||||
print(f'The maximum difference between torch and triton is '
|
||||
f'{torch.max(torch.abs(output_torch - output_triton))}')
|
||||
|
||||
# %%
|
||||
# Seems like we're good to go!
|
||||
@@ -108,9 +105,7 @@ print(
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['size'], # Argument names to use as an x-axis for the plot.
|
||||
x_vals=[
|
||||
2 ** i for i in range(12, 28, 1)
|
||||
], # Different possible values for `x_name`.
|
||||
x_vals=[2**i for i in range(12, 28, 1)], # Different possible values for `x_name`.
|
||||
x_log=True, # x axis is logarithmic.
|
||||
line_arg='provider', # Argument name whose value corresponds to a different line in the plot.
|
||||
line_vals=['triton', 'torch'], # Possible values for `line_arg`.
|
||||
@@ -119,8 +114,7 @@ print(
|
||||
ylabel='GB/s', # Label name for the y-axis.
|
||||
plot_name='vector-add-performance', # Name for the plot. Used also as a file name for saving the plot.
|
||||
args={}, # Values for function arguments not in `x_names` and `y_name`.
|
||||
)
|
||||
)
|
||||
))
|
||||
def benchmark(size, provider):
|
||||
x = torch.rand(size, device='cuda', dtype=torch.float32)
|
||||
y = torch.rand(size, device='cuda', dtype=torch.float32)
|
||||
|
||||
@@ -71,10 +71,7 @@ def naive_softmax(x):
|
||||
|
||||
|
||||
@triton.jit
|
||||
def softmax_kernel(
|
||||
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
|
||||
BLOCK_SIZE: tl.constexpr
|
||||
):
|
||||
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
|
||||
# The rows of the softmax are independent, so we parallelize across those
|
||||
row_idx = tl.program_id(0)
|
||||
# The stride represents how much we need to increase the pointer to advance 1 row
|
||||
@@ -118,7 +115,7 @@ def softmax(x):
|
||||
y = torch.empty_like(x)
|
||||
# Enqueue kernel. The 1D launch grid is simple: we have one kernel instance
# per row of the input matrix
|
||||
softmax_kernel[(n_rows,)](
|
||||
softmax_kernel[(n_rows, )](
|
||||
y,
|
||||
x,
|
||||
x.stride(0),
|
||||
@@ -158,9 +155,7 @@ assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['N'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[
|
||||
128 * i for i in range(2, 100)
|
||||
], # different possible values for `x_name`
|
||||
x_vals=[128 * i for i in range(2, 100)], # different possible values for `x_name`
|
||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||
line_vals=[
|
||||
'triton',
|
||||
@@ -176,8 +171,7 @@ assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
|
||||
ylabel="GB/s", # label name for the y-axis
|
||||
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
|
||||
args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`
|
||||
)
|
||||
)
|
||||
))
|
||||
def benchmark(M, N, provider):
|
||||
x = torch.randn(M, N, device='cuda', dtype=torch.float32)
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
@@ -165,6 +165,7 @@ import pytest
|
||||
# provided configs
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
<<<<<<< HEAD
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
@@ -179,6 +180,24 @@ import pytest
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=8, num_stages=0),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, num_warps=4, num_stages=0),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, num_warps=4, num_stages=0),
|
||||
=======
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
|
||||
num_warps=8),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
|
||||
num_warps=2),
|
||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
|
||||
num_warps=2),
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
@@ -187,6 +206,7 @@ import pytest
|
||||
})
|
||||
@triton.jit
|
||||
def matmul_kernel(
|
||||
<<<<<<< HEAD
|
||||
# Pointers to matrices
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
# Matrix dimensions
|
||||
@@ -202,6 +222,22 @@ def matmul_kernel(
|
||||
EVEN_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
ACTIVATION: tl.constexpr,
|
||||
=======
|
||||
# Pointers to matrices
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
# Matrix dimensions
|
||||
M, N, K,
|
||||
# The stride variables represent how much to increase the ptr by when moving by 1
|
||||
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
|
||||
# by to get the element one row down (A has M rows).
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
|
||||
GROUP_SIZE_M: tl.constexpr, #
|
||||
ACTIVATION: tl.constexpr #
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
):
|
||||
"""Kernel for computing the matmul C = A x B.
|
||||
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
|
||||
@@ -300,16 +336,14 @@ def matmul(a, b, activation=""):
|
||||
# Allocates output.
|
||||
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
|
||||
# 1D launch kernel where each block gets its own program.
|
||||
grid = lambda META: (
|
||||
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
|
||||
)
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
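# (e.g. for M = N = 1024 with BLOCK_SIZE_M = 128 and BLOCK_SIZE_N = 256 this is
#  triton.cdiv(1024, 128) * triton.cdiv(1024, 256) = 8 * 4 = 32 programs.)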
matmul_kernel[grid](
|
||||
a, b, c,
|
||||
M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
ACTIVATION=activation
|
||||
a, b, c, #
|
||||
M, N, K, #
|
||||
a.stride(0), a.stride(1), #
|
||||
b.stride(0), b.stride(1), #
|
||||
c.stride(0), c.stride(1), #
|
||||
ACTIVATION=activation #
|
||||
)
|
||||
return c
|
||||
|
||||
@@ -363,6 +397,7 @@ verbose = False
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['M', 'N', 'K'], # Argument names to use as an x-axis for the plot
|
||||
<<<<<<< HEAD
|
||||
x_vals=[
|
||||
(1024, 1024, 1024),
|
||||
(2048, 2048, 2048),
|
||||
@@ -370,6 +405,9 @@ verbose = False
|
||||
(8192, 8192, 8192),
|
||||
(9728, 8192, 65536)
|
||||
], # Different possible values for `x_name`
|
||||
=======
|
||||
x_vals=[128 * i for i in range(2, 33)], # Different possible values for `x_name`
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
line_arg='provider', # Argument name whose value corresponds to a different line in the plot
|
||||
# Possible values for `line_arg`
|
||||
line_vals=['rocblas', 'triton'],
|
||||
@@ -380,8 +418,7 @@ verbose = False
|
||||
ylabel="TFLOPS", # Label name for the y-axis
|
||||
plot_name="matmul-performance", # Name for the plot, used also as a file name for saving the plot.
|
||||
args={},
|
||||
)
|
||||
)
|
||||
))
|
||||
def benchmark(M, N, K, provider):
|
||||
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
|
||||
|
||||
@@ -32,7 +32,6 @@ In doing so, you will learn about:
|
||||
#
|
||||
# Let's first take a look at the baseline implementation.
|
||||
|
||||
|
||||
import tabulate
|
||||
import torch
|
||||
|
||||
@@ -66,22 +65,22 @@ def dropout(x, x_keep, p):
|
||||
output = torch.empty_like(x)
|
||||
assert x.is_contiguous()
|
||||
n_elements = x.numel()
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
|
||||
_dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
|
||||
return output
|
||||
|
||||
|
||||
# Input tensor
|
||||
x = torch.randn(size=(10,)).cuda()
|
||||
x = torch.randn(size=(10, )).cuda()
|
||||
# Dropout mask
|
||||
p = 0.5
|
||||
x_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()
|
||||
x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).cuda()
|
||||
#
|
||||
output = dropout(x, x_keep=x_keep, p=p)
|
||||
print(tabulate.tabulate([
|
||||
["input"] + x.tolist(),
|
||||
["keep mask"] + x_keep.tolist(),
|
||||
["output"] + output.tolist()
|
||||
["output"] + output.tolist(),
|
||||
]))
|
||||
|
||||
# %%
|
||||
@@ -134,23 +133,24 @@ def seeded_dropout(x, p, seed):
|
||||
output = torch.empty_like(x)
|
||||
assert x.is_contiguous()
|
||||
n_elements = x.numel()
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
|
||||
_seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
|
||||
return output
|
||||
|
||||
|
||||
x = torch.randn(size=(10,)).cuda()
|
||||
x = torch.randn(size=(10, )).cuda()
|
||||
# Compare this to the baseline - dropout mask is never instantiated!
|
||||
output = seeded_dropout(x, p=0.5, seed=123)
|
||||
output2 = seeded_dropout(x, p=0.5, seed=123)
|
||||
output3 = seeded_dropout(x, p=0.5, seed=512)
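# The determinism claim can also be checked directly on the tensors computed above
# (illustrative; no new inputs are introduced):
assert torch.equal(output, output2)  # same seed -> identical dropout result
print(torch.equal(output, output3))  # different seed -> typically False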
print(tabulate.tabulate([
|
||||
["input"] + x.tolist(),
|
||||
["output (seed = 123)"] + output.tolist(),
|
||||
["output (seed = 123)"] + output2.tolist(),
|
||||
["output (seed = 512)"] + output3.tolist()
|
||||
]))
|
||||
print(
|
||||
tabulate.tabulate([
|
||||
["input"] + x.tolist(),
|
||||
["output (seed = 123)"] + output.tolist(),
|
||||
["output (seed = 123)"] + output2.tolist(),
|
||||
["output (seed = 512)"] + output3.tolist(),
|
||||
]))
|
||||
|
||||
# %%
|
||||
# Et Voilà! We have a triton kernel that applies the same dropout mask provided the seed is the same!
|
||||
|
||||
@@ -126,24 +126,22 @@ def _layer_norm_fwd_fused(
|
||||
# In Stage 2, the buffers are further reduced to compute the final :math:`\nabla_{w}` and :math:`\nabla_{b}`.
|
||||
# In the following implementation, Stage 1 is implemented by the function :code:`_layer_norm_bwd_dx_fused` and Stage 2 is implemented by the function :code:`_layer_norm_bwd_dwdb`.
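# A torch-only sketch of that two-stage reduction, shapes only (the real Stage 1
# kernel accumulates into the GROUP_SIZE_M partial buffers under a lock):
import torch
GROUP_SIZE_M, N = 64, 4096
_dw = torch.randn(GROUP_SIZE_M, N)  # Stage 1 output: one partial dW row per lock group
_db = torch.randn(GROUP_SIZE_M, N)
dw, db = _dw.sum(dim=0), _db.sum(dim=0)  # Stage 2: reduce partials to the final (N,) gradients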
@triton.jit
|
||||
def _layer_norm_bwd_dx_fused(
|
||||
DX, # pointer to the input gradient
|
||||
DY, # pointer to the output gradient
|
||||
DW, # pointer to the partial sum of weights gradient
|
||||
DB, # pointer to the partial sum of biases gradient
|
||||
X, # pointer to the input
|
||||
W, # pointer to the weights
|
||||
B, # pointer to the biases
|
||||
Mean, # pointer to the mean
|
||||
Rstd, # pointer to the 1/std
|
||||
Lock, # pointer to the lock
|
||||
stride, # how much to increase the pointer when moving by 1 row
|
||||
N, # number of columns in X
|
||||
eps, # epsilon to avoid division by zero
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr
|
||||
):
|
||||
def _layer_norm_bwd_dx_fused(DX, # pointer to the input gradient
|
||||
DY, # pointer to the output gradient
|
||||
DW, # pointer to the partial sum of weights gradient
|
||||
DB, # pointer to the partial sum of biases gradient
|
||||
X, # pointer to the input
|
||||
W, # pointer to the weights
|
||||
B, # pointer to the biases
|
||||
Mean, # pointer to the mean
|
||||
Rstd, # pointer to the 1/std
|
||||
Lock, # pointer to the lock
|
||||
stride, # how much to increase the pointer when moving by 1 row
|
||||
N, # number of columns in X
|
||||
eps, # epsilon to avoid division by zero
|
||||
GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
|
||||
# Map the program id to the elements of X, DX, and DY it should compute.
|
||||
row = tl.program_id(0)
|
||||
cols = tl.arange(0, BLOCK_SIZE_N)
|
||||
@@ -192,16 +190,13 @@ def _layer_norm_bwd_dx_fused(
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _layer_norm_bwd_dwdb(
|
||||
DW, # pointer to the partial sum of weights gradient
|
||||
DB, # pointer to the partial sum of biases gradient
|
||||
FINAL_DW, # pointer to the weights gradient
|
||||
FINAL_DB, # pointer to the biases gradient
|
||||
M, # GROUP_SIZE_M
|
||||
N, # number of columns
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr
|
||||
):
|
||||
def _layer_norm_bwd_dwdb(DW, # pointer to the partial sum of weights gradient
|
||||
DB, # pointer to the partial sum of biases gradient
|
||||
FINAL_DW, # pointer to the weights gradient
|
||||
FINAL_DB, # pointer to the biases gradient
|
||||
M, # GROUP_SIZE_M
|
||||
N, # number of columns
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
|
||||
# Map the program id to the elements of DW and DB it should compute.
|
||||
pid = tl.program_id(0)
|
||||
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
@@ -258,9 +253,10 @@ class LayerNorm(torch.autograd.Function):
|
||||
else:
|
||||
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
|
||||
# enqueue kernel
|
||||
_layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
|
||||
x_arg.stride(0), N, eps,
|
||||
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1)
|
||||
_layer_norm_fwd_fused[(M, )]( #
|
||||
x_arg, y, weight, bias, mean, rstd, #
|
||||
x_arg.stride(0), N, eps, #
|
||||
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1)
|
||||
ctx.save_for_backward(x, weight, bias, mean, rstd)
|
||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||
ctx.num_warps = num_warps
|
||||
@@ -280,23 +276,25 @@ class LayerNorm(torch.autograd.Function):
|
||||
locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
|
||||
_dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
|
||||
_db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
|
||||
dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
|
||||
db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
|
||||
dw = torch.empty((w.shape[0], ), dtype=w.dtype, device=w.device)
|
||||
db = torch.empty((w.shape[0], ), dtype=w.dtype, device=w.device)
|
||||
dx = torch.empty_like(dy)
|
||||
# enqueue kernel using forward pass heuristics
|
||||
# also compute partial sums for DW and DB
|
||||
x_arg = x.reshape(-1, x.shape[-1])
|
||||
M, N = x_arg.shape
|
||||
_layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
|
||||
x_arg.stride(0), N, ctx.eps,
|
||||
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
|
||||
GROUP_SIZE_M=GROUP_SIZE_M,
|
||||
num_warps=ctx.num_warps)
|
||||
_layer_norm_bwd_dx_fused[(M, )]( #
|
||||
dx, dy, _dw, _db, x, w, b, m, v, locks, #
|
||||
x_arg.stride(0), N, ctx.eps, #
|
||||
BLOCK_SIZE_N=ctx.BLOCK_SIZE, #
|
||||
GROUP_SIZE_M=GROUP_SIZE_M, #
|
||||
num_warps=ctx.num_warps)
|
||||
grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
|
||||
# accumulate partial sums in separate kernel
|
||||
_layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
|
||||
BLOCK_SIZE_M=32,
|
||||
BLOCK_SIZE_N=128, num_ctas=1)
|
||||
_layer_norm_bwd_dwdb[grid](
|
||||
_dw, _db, dw, db, GROUP_SIZE_M, N, #
|
||||
BLOCK_SIZE_M=32, #
|
||||
BLOCK_SIZE_N=128, num_ctas=1)
|
||||
return dx, None, dw, db, None
|
||||
|
||||
|
||||
@@ -340,10 +338,16 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
|
||||
line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),
|
||||
styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
|
||||
ylabel='GB/s',
|
||||
<<<<<<< HEAD
|
||||
plot_name='layer-norm-forward',
|
||||
args={'M': 4096, 'dtype': torch.float16, 'mode': 'forward'}
|
||||
)
|
||||
)
|
||||
=======
|
||||
plot_name='layer-norm-backward',
|
||||
args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'},
|
||||
))
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
|
||||
# create data
|
||||
x_shape = (M, N)
|
||||
@@ -356,24 +360,34 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
# utility functions
|
||||
if provider == 'triton':
|
||||
def y_fwd(): return layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704
|
||||
if provider == 'torch':
|
||||
def y_fwd(): return torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704
|
||||
if provider == 'apex':
|
||||
apex_layer_norm = apex.normalization.FusedLayerNorm(
|
||||
w_shape).to(x.device).to(x.dtype)
|
||||
|
||||
def y_fwd(): return apex_layer_norm(x) # noqa: F811, E704
|
||||
def y_fwd():
|
||||
return layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704
|
||||
|
||||
if provider == 'torch':
|
||||
|
||||
def y_fwd():
|
||||
return torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704
|
||||
|
||||
if provider == 'apex':
|
||||
apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)
|
||||
|
||||
def y_fwd():
|
||||
return apex_layer_norm(x) # noqa: F811, E704
|
||||
|
||||
# forward pass
|
||||
if mode == 'forward':
|
||||
gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500)
|
||||
# backward pass
|
||||
if mode == 'backward':
|
||||
def gbps(ms): return 3 * x.numel() * x.element_size() / ms * 1e-6 # noqa: F811, E704
|
||||
|
||||
def gbps(ms):
|
||||
return 3 * x.numel() * x.element_size() / ms * 1e-6 # noqa: F811, E704
|
||||
|
||||
y = y_fwd()
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
|
||||
quantiles=quantiles, grad_to_none=[x], rep=500)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), quantiles=quantiles,
|
||||
grad_to_none=[x], rep=500)
|
||||
return gbps(ms), gbps(max_ms), gbps(min_ms)
|
||||
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ if TORCH_HAS_FP8E5FNUZ:
|
||||
TORCH_HAS_FP8 = True
|
||||
|
||||
@triton.jit
|
||||
<<<<<<< HEAD
|
||||
def _attn_fwd_inner(
|
||||
acc, l_i, m_i, q,
|
||||
K_block_ptr, V_block_ptr,
|
||||
@@ -42,6 +43,14 @@ def _attn_fwd_inner(
|
||||
N_CTX,
|
||||
pre_load_v: tl.constexpr,
|
||||
):
|
||||
=======
|
||||
def _attn_fwd_inner(acc, l_i, m_i, q, #
|
||||
K_block_ptr, V_block_ptr, #
|
||||
start_m, qk_scale, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, #
|
||||
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #
|
||||
N_CTX: tl.constexpr):
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
# range of values handled by this stage
|
||||
if STAGE == 1:
|
||||
lo, hi = 0, start_m * BLOCK_M
|
||||
@@ -83,6 +92,7 @@ def _attn_fwd_inner(
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
||||
return acc, l_i, m_i
|
||||
|
||||
|
||||
# We don't run auto-tuning every time to keep the tutorial fast. Uncommenting
|
||||
# the code below and commenting out the equivalent parameters is convenient for
|
||||
# re-tuning.
|
||||
@@ -99,6 +109,7 @@ def _attn_fwd_inner(
|
||||
|
||||
|
||||
@triton.jit
|
||||
<<<<<<< HEAD
|
||||
def _attn_fwd(
|
||||
Q, K, V, sm_scale, M, Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
@@ -113,6 +124,20 @@ def _attn_fwd(
|
||||
BLOCK_N: tl.constexpr,
|
||||
pre_load_v: tl.constexpr,
|
||||
):
|
||||
=======
|
||||
def _attn_fwd(Q, K, V, sm_scale, M, Out, #
|
||||
stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vk, stride_vn, #
|
||||
stride_oz, stride_oh, stride_om, stride_on, #
|
||||
Z, H, #
|
||||
N_CTX: tl.constexpr, #
|
||||
BLOCK_M: tl.constexpr, #
|
||||
BLOCK_DMODEL: tl.constexpr, #
|
||||
BLOCK_N: tl.constexpr, #
|
||||
STAGE: tl.constexpr #
|
||||
):
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
qvk_offset = off_hz * stride_qh
|
||||
@@ -168,6 +193,7 @@ def _attn_fwd(
|
||||
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
|
||||
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
|
||||
if STAGE & 1:
|
||||
<<<<<<< HEAD
|
||||
acc, l_i, m_i = _attn_fwd_inner(
|
||||
acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
|
||||
start_m,
|
||||
@@ -175,11 +201,19 @@ def _attn_fwd(
|
||||
4 - STAGE, offs_m, offs_n,
|
||||
N_CTX, pre_load_v,
|
||||
)
|
||||
=======
|
||||
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, #
|
||||
start_m, qk_scale, #
|
||||
BLOCK_M, BLOCK_DMODEL, BLOCK_N, #
|
||||
4 - STAGE, offs_m, offs_n, N_CTX #
|
||||
)
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
# stage 2: on-band
|
||||
if STAGE & 2:
|
||||
# barrier makes it easier for the compiler to schedule the
|
||||
# two loops independently
|
||||
tl.debug_barrier()
|
||||
<<<<<<< HEAD
|
||||
acc, l_i, m_i = _attn_fwd_inner(
|
||||
acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
|
||||
start_m,
|
||||
@@ -187,6 +221,13 @@ def _attn_fwd(
|
||||
2, offs_m, offs_n,
|
||||
N_CTX, pre_load_v,
|
||||
)
|
||||
=======
|
||||
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, #
|
||||
start_m, qk_scale, #
|
||||
BLOCK_M, BLOCK_DMODEL, BLOCK_N, #
|
||||
2, offs_m, offs_n, N_CTX #
|
||||
)
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
# epilogue
|
||||
# write back m
|
||||
acc = acc / l_i[:, None]
|
||||
@@ -197,8 +238,14 @@ def _attn_fwd(
|
||||
|
||||
@triton.jit
|
||||
def _attn_bwd_preprocess(O, DO, #
|
||||
<<<<<<< HEAD
|
||||
NewDO, Delta, #
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, #
|
||||
=======
|
||||
Delta, #
|
||||
Z, H, N_CTX, #
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr #
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
):
|
||||
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, D_HEAD)
|
||||
@@ -212,6 +259,7 @@ def _attn_bwd_preprocess(O, DO, #
|
||||
|
||||
|
||||
@triton.jit
|
||||
<<<<<<< HEAD
|
||||
def _bwd_kernel_dk_dv(
|
||||
Q, K, V, sm_scale, Out, DO,
|
||||
DK, DV,
|
||||
@@ -422,11 +470,242 @@ def _bwd_kernel_dq(
|
||||
order=(1, 0)
|
||||
)
|
||||
tl.store(DQ_block_ptr, (dq * sm_scale).to(tl.float16))
|
||||
=======
|
||||
def _attn_bwd_dkdv(dk, dv, #
|
||||
Q, k, v, sm_scale, #
|
||||
DO, #
|
||||
M, D, #
|
||||
# shared by Q/K/V/DO.
|
||||
stride_tok, stride_d, #
|
||||
H, N_CTX, BLOCK_M1: tl.constexpr, #
|
||||
BLOCK_N1: tl.constexpr, #
|
||||
BLOCK_DMODEL: tl.constexpr, #
|
||||
# Filled in by the wrapper.
|
||||
start_n, start_m, num_steps, #
|
||||
MASK: tl.constexpr):
|
||||
offs_m = start_m + tl.arange(0, BLOCK_M1)
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N1)
|
||||
offs_k = tl.arange(0, BLOCK_DMODEL)
|
||||
qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
|
||||
do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
|
||||
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
|
||||
curr_m = start_m
|
||||
step_m = BLOCK_M1
|
||||
for blk_idx in range(num_steps):
|
||||
qT = tl.load(qT_ptrs)
|
||||
# Load m before computing qk to reduce pipeline stall.
|
||||
offs_m = curr_m + tl.arange(0, BLOCK_M1)
|
||||
m = tl.load(M + offs_m)
|
||||
qkT = tl.dot(k, qT)
|
||||
pT = tl.math.exp2(qkT - m[None, :])
|
||||
# Autoregressive masking.
|
||||
if MASK:
|
||||
mask = (offs_m[None, :] >= offs_n[:, None])
|
||||
pT = tl.where(mask, pT, 0.0)
|
||||
do = tl.load(do_ptrs)
|
||||
# Compute dV.
|
||||
ppT = pT
|
||||
ppT = ppT.to(tl.float16)
|
||||
dv += tl.dot(ppT, do)
|
||||
# D (= delta) is pre-divided by ds_scale.
|
||||
Di = tl.load(D + offs_m)
|
||||
# Compute dP and dS.
|
||||
dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
|
||||
dsT = pT * (dpT - Di[None, :])
|
||||
dsT = dsT.to(tl.float16)
|
||||
dk += tl.dot(dsT, tl.trans(qT))
|
||||
# Increment pointers.
|
||||
curr_m += step_m
|
||||
qT_ptrs += step_m * stride_tok
|
||||
do_ptrs += step_m * stride_tok
|
||||
return dk, dv
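For reference, the loop above implements the usual softmax-attention backward identities: dV = P^T @ dO, dP = dO @ V^T, dS = P * (dP - delta) and dK = dS^T @ Q, where delta is the row-wise sum of dO * O produced by the preprocess kernel. A dense PyTorch sketch of the same math (hypothetical unblocked shapes, no scaling), handy for cross-checking a block implementation:

import torch

def dense_dk_dv(q, k, v, do, p, delta):
    # p is the materialized attention matrix; delta[i] = sum_j do[i, j] * o[i, j]
    dv = p.t() @ do                 # mirrors dv += tl.dot(ppT, do)
    dp = do @ v.t()                 # mirrors dpT = tl.dot(v, tl.trans(do))
    ds = p * (dp - delta[:, None])  # mirrors dsT = pT * (dpT - Di[None, :])
    dk = ds.t() @ q                 # mirrors dk += tl.dot(dsT, tl.trans(qT))
    return dk, dv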
|
||||
|
||||
|
||||
# the main inner-loop logic for computing dQ
|
||||
@triton.jit
|
||||
def _attn_bwd_dq(dq, q, K, V, #
|
||||
do, m, D,
|
||||
# shared by Q/K/V/DO.
|
||||
stride_tok, stride_d, #
|
||||
H, N_CTX, #
|
||||
BLOCK_M2: tl.constexpr, #
|
||||
BLOCK_N2: tl.constexpr, #
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
# Filled in by the wrapper.
|
||||
start_m, start_n, num_steps, #
|
||||
MASK: tl.constexpr):
|
||||
offs_m = start_m + tl.arange(0, BLOCK_M2)
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N2)
|
||||
offs_k = tl.arange(0, BLOCK_DMODEL)
|
||||
kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
|
||||
vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
|
||||
# D (= delta) is pre-divided by ds_scale.
|
||||
Di = tl.load(D + offs_m)
|
||||
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
|
||||
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
|
||||
curr_n = start_n
|
||||
step_n = BLOCK_N2
|
||||
for blk_idx in range(num_steps):
|
||||
kT = tl.load(kT_ptrs)
|
||||
vT = tl.load(vT_ptrs)
|
||||
qk = tl.dot(q, kT)
|
||||
p = tl.math.exp2(qk - m)
|
||||
# Autoregressive masking.
|
||||
if MASK:
|
||||
offs_n = curr_n + tl.arange(0, BLOCK_N2)
|
||||
mask = (offs_m[:, None] >= offs_n[None, :])
|
||||
p = tl.where(mask, p, 0.0)
|
||||
# Compute dP and dS.
|
||||
dp = tl.dot(do, vT).to(tl.float32)
|
||||
ds = p * (dp - Di[:, None])
|
||||
ds = ds.to(tl.float16)
|
||||
# Compute dQ.
|
||||
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
|
||||
dq += tl.dot(ds, tl.trans(kT))
|
||||
# Increment pointers.
|
||||
curr_n += step_n
|
||||
kT_ptrs += step_n * stride_tok
|
||||
vT_ptrs += step_n * stride_tok
|
||||
return dq
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _attn_bwd(Q, K, V, sm_scale, #
|
||||
DO, #
|
||||
DQ, DK, DV, #
|
||||
M, D,
|
||||
# shared by Q/K/V/DO.
|
||||
stride_z, stride_h, stride_tok, stride_d, #
|
||||
H, N_CTX, #
|
||||
BLOCK_M1: tl.constexpr, #
|
||||
BLOCK_N1: tl.constexpr, #
|
||||
BLOCK_M2: tl.constexpr, #
|
||||
BLOCK_N2: tl.constexpr, #
|
||||
BLK_SLICE_FACTOR: tl.constexpr, #
|
||||
BLOCK_DMODEL: tl.constexpr):
|
||||
LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
|
||||
|
||||
bhid = tl.program_id(2)
|
||||
off_chz = (bhid * N_CTX).to(tl.int64)
|
||||
adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
|
||||
pid = tl.program_id(0)
|
||||
|
||||
# offset pointers for batch/head
|
||||
Q += adj
|
||||
K += adj
|
||||
V += adj
|
||||
DO += adj
|
||||
DQ += adj
|
||||
DK += adj
|
||||
DV += adj
|
||||
M += off_chz
|
||||
D += off_chz
|
||||
|
||||
# load scales
|
||||
offs_k = tl.arange(0, BLOCK_DMODEL)
|
||||
|
||||
start_n = pid * BLOCK_N1
|
||||
start_m = start_n
|
||||
|
||||
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N1)
|
||||
|
||||
dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
|
||||
dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
# load K and V: they stay in SRAM throughout the inner loop.
|
||||
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
|
||||
num_steps = BLOCK_N1 // MASK_BLOCK_M1
|
||||
|
||||
dk, dv = _attn_bwd_dkdv(dk, dv, #
|
||||
Q, k, v, sm_scale, #
|
||||
DO, #
|
||||
M, D, #
|
||||
stride_tok, stride_d, #
|
||||
H, N_CTX, #
|
||||
MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, #
|
||||
start_n, start_m, num_steps, #
|
||||
MASK=True #
|
||||
)
|
||||
|
||||
start_m += num_steps * MASK_BLOCK_M1
|
||||
num_steps = (N_CTX - start_m) // BLOCK_M1
|
||||
|
||||
# Compute dK and dV for non-masked blocks.
|
||||
dk, dv = _attn_bwd_dkdv( #
|
||||
dk, dv, #
|
||||
Q, k, v, sm_scale, #
|
||||
DO, #
|
||||
M, D, #
|
||||
stride_tok, stride_d, #
|
||||
H, N_CTX, #
|
||||
BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, #
|
||||
start_n, start_m, num_steps, #
|
||||
MASK=False #
|
||||
)
|
||||
|
||||
dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
tl.store(dv_ptrs, dv)
|
||||
|
||||
# Write back dK.
|
||||
dk *= sm_scale
|
||||
dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
tl.store(dk_ptrs, dk)
|
||||
|
||||
# THIS BLOCK DOES DQ:
|
||||
start_m = pid * BLOCK_M2
|
||||
end_n = start_m + BLOCK_M2
|
||||
|
||||
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
|
||||
offs_m = start_m + tl.arange(0, BLOCK_M2)
|
||||
|
||||
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
|
||||
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
|
||||
m = tl.load(M + offs_m)
|
||||
m = m[:, None]
|
||||
|
||||
# Compute dQ for masked (diagonal) blocks.
|
||||
# NOTE: This code scans each row of QK^T backward (from right to left,
|
||||
# but inside each call to _attn_bwd_dq, from left to right), but that's
|
||||
# not due to anything important. I just wanted to reuse the loop
|
||||
# structure for dK & dV above as much as possible.
|
||||
num_steps = BLOCK_M2 // MASK_BLOCK_N2
|
||||
dq = _attn_bwd_dq(dq, q, K, V, #
|
||||
do, m, D, #
|
||||
stride_tok, stride_d, #
|
||||
H, N_CTX, #
|
||||
BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL, #
|
||||
start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, #
|
||||
MASK=True #
|
||||
)
|
||||
end_n -= num_steps * MASK_BLOCK_N2
|
||||
# stage 2
|
||||
num_steps = end_n // BLOCK_N2
|
||||
dq = _attn_bwd_dq(dq, q, K, V, #
|
||||
do, m, D, #
|
||||
stride_tok, stride_d, #
|
||||
H, N_CTX, #
|
||||
BLOCK_M2, BLOCK_N2, BLOCK_DMODEL, #
|
||||
start_m, end_n - num_steps * BLOCK_N2, num_steps, #
|
||||
MASK=False #
|
||||
)
|
||||
# Write back dQ.
|
||||
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
dq *= LN2
|
||||
tl.store(dq_ptrs, dq)
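The LN2 constant above (and RCP_LN2 in the Python wrapper further down) exists because these kernels compute the softmax with tl.math.exp2 rather than exp: K is pre-scaled by sm_scale * RCP_LN2 so that exp2 of the scaled scores equals exp of the original ones, and dq is multiplied by LN2 at the end to undo that change of base in the gradient. A minimal NumPy check of the underlying identity, independent of the kernel:

import numpy as np

RCP_LN2 = 1.4426950408889634  # 1 / ln(2), i.e. log2(e)
x = np.linspace(-3.0, 3.0, 7)
assert np.allclose(np.exp(x), np.exp2(x * RCP_LN2))  # exp(x) == 2**(x * log2(e))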
|
||||
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
|
||||
empty = torch.empty(128, device="cuda")
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
|
||||
# shape constraints
|
||||
@@ -453,6 +732,7 @@ class _attention(torch.autograd.Function):
|
||||
M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
|
||||
_attn_fwd[grid](
|
||||
<<<<<<< HEAD
|
||||
q, k, v, sm_scale, M, o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
@@ -462,6 +742,21 @@ class _attention(torch.autograd.Function):
|
||||
N_CTX=q.shape[2],
|
||||
BLOCK_DMODEL=Lk,
|
||||
STAGE=stage,
|
||||
=======
|
||||
q, k, v, sm_scale, M, o, #
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
|
||||
q.shape[0], q.shape[1], #
|
||||
N_CTX=q.shape[2], #
|
||||
BLOCK_M=BLOCK_M, #
|
||||
BLOCK_N=BLOCK_N, #
|
||||
BLOCK_DMODEL=Lk, #
|
||||
STAGE=stage, #
|
||||
num_warps=num_warps, #
|
||||
num_stages=num_stages #
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
)
|
||||
|
||||
## restore the grid for bwd kernel
|
||||
@@ -493,6 +788,7 @@ class _attention(torch.autograd.Function):
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
BATCH, N_HEAD, N_CTX = q.shape[:3]
|
||||
<<<<<<< HEAD
|
||||
delta = torch.empty_like(L)
|
||||
do_scaled = torch.empty_like(do)
|
||||
# Figure out what BLOCK size fwd used and adjust num_blocks accordingly.
|
||||
@@ -506,6 +802,39 @@ class _attention(torch.autograd.Function):
|
||||
o, do, #
|
||||
do_scaled, delta, #
|
||||
BLOCK_M=block_scale * BLOCK, D_HEAD=ctx.BLOCK_DMODEL, #
|
||||
=======
|
||||
PRE_BLOCK = 128
|
||||
NUM_WARPS, NUM_STAGES = 4, 1
|
||||
if torch.cuda.get_device_capability()[0] == 9:
|
||||
NUM_STAGES = 5
|
||||
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
|
||||
BLK_SLICE_FACTOR = 2
|
||||
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
|
||||
arg_k = k
|
||||
arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
|
||||
PRE_BLOCK = 128
|
||||
assert N_CTX % PRE_BLOCK == 0
|
||||
pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
|
||||
delta = torch.empty_like(M)
|
||||
_attn_bwd_preprocess[pre_grid](
|
||||
o, do, #
|
||||
delta, #
|
||||
BATCH, N_HEAD, N_CTX, #
|
||||
BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL #
|
||||
)
|
||||
grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
|
||||
_attn_bwd[grid](
|
||||
q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, #
|
||||
M, delta, #
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
|
||||
N_HEAD, N_CTX, #
|
||||
BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, #
|
||||
BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, #
|
||||
BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, #
|
||||
num_warps=NUM_WARPS, #
|
||||
num_stages=NUM_STAGES #
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
)
|
||||
if not ctx.split_kernel:
|
||||
_bwd_kernel[(ctx.grid[1],)](
|
||||
@@ -599,11 +928,17 @@ def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
|
||||
])
|
||||
def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
torch.manual_seed(20)
|
||||
<<<<<<< HEAD
|
||||
causal = True
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
|
||||
=======
|
||||
q = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
|
||||
k = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
|
||||
v = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
sm_scale = 0.5
|
||||
split_kernel = True
|
||||
dout = torch.randn_like(q)
|
||||
@@ -672,17 +1007,26 @@ for mode in ['fwd', 'bwd']:
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-{mode}-d{D_HEAD}-causal={causal}',
|
||||
args={
|
||||
<<<<<<< HEAD
|
||||
'D_HEAD': D_HEAD,
|
||||
'dtype': torch.float16,
|
||||
'mode': mode,
|
||||
'causal': causal})
|
||||
)
|
||||
=======
|
||||
"H": N_HEADS,
|
||||
"BATCH": BATCH,
|
||||
"D_HEAD": D_HEAD,
|
||||
"dtype": torch.float16,
|
||||
"mode": mode,
|
||||
"causal": causal,
|
||||
},
|
||||
))
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def bench_flash_attention(
|
||||
BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"
|
||||
):
|
||||
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"):
|
||||
assert mode in ["fwd", "bwd"]
|
||||
warmup = 25
|
||||
rep = 100
|
||||
@@ -706,9 +1050,7 @@ def bench_flash_attention(
|
||||
fn = lambda: o.backward(do, retain_graph=True)
|
||||
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
||||
if provider == "flash":
|
||||
qkv = torch.randn(
|
||||
(BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True
|
||||
)
|
||||
qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
|
||||
fn = lambda: flash_attn_func(qkv, causal=causal)
|
||||
if mode == "bwd":
|
||||
o = fn()
|
||||
|
||||
@@ -22,10 +22,10 @@ import triton.language as tl
|
||||
|
||||
@triton.jit
|
||||
def asin_kernel(
|
||||
x_ptr,
|
||||
y_ptr,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
x_ptr,
|
||||
y_ptr,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
@@ -35,12 +35,12 @@ def asin_kernel(
|
||||
x = tl.math.asin(x)
|
||||
tl.store(y_ptr + offsets, x, mask=mask)
|
||||
|
||||
|
||||
# %%
|
||||
# Using the default libdevice library path
|
||||
# -----------------------------------------
|
||||
# We can use the default libdevice library path encoded in `triton/language/math.py`
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
size = 98432
|
||||
x = torch.rand(size, device='cuda')
|
||||
@@ -48,14 +48,12 @@ output_triton = torch.zeros(size, device='cuda')
|
||||
output_torch = torch.asin(x)
|
||||
assert x.is_cuda and output_triton.is_cuda
|
||||
n_elements = output_torch.numel()
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
|
||||
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)
|
||||
print(output_torch)
|
||||
print(output_triton)
|
||||
print(
|
||||
f'The maximum difference between torch and triton is '
|
||||
f'{torch.max(torch.abs(output_torch - output_triton))}'
|
||||
)
|
||||
print(f'The maximum difference between torch and triton is '
|
||||
f'{torch.max(torch.abs(output_torch - output_triton))}')
|
||||
|
||||
# %%
|
||||
# Customize the libdevice library path
|
||||
@@ -67,7 +65,5 @@ asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,
|
||||
extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'})
|
||||
print(output_torch)
|
||||
print(output_triton)
|
||||
print(
|
||||
f'The maximum difference between torch and triton is '
|
||||
f'{torch.max(torch.abs(output_torch - output_triton))}'
|
||||
)
|
||||
print(f'The maximum difference between torch and triton is '
|
||||
f'{torch.max(torch.abs(output_torch - output_triton))}')
|
||||
|
||||
@@ -98,14 +98,22 @@ import triton.language as tl
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
|
||||
num_warps=2),
|
||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
|
||||
num_warps=2),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
@@ -118,13 +126,11 @@ def matmul_kernel_with_block_pointers(
|
||||
# The stride variables represent how much to increase the ptr by when moving by 1
|
||||
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
|
||||
# by to get the element one row down (A has M rows).
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr
|
||||
):
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
|
||||
"""Kernel for computing the matmul C = A x B.
|
||||
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
|
||||
"""
|
||||
@@ -196,16 +202,13 @@ def matmul(a, b):
|
||||
# Allocates output.
|
||||
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
|
||||
# 1D launch kernel where each block gets its own program.
|
||||
grid = lambda META: (
|
||||
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
|
||||
)
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
|
||||
matmul_kernel_with_block_pointers[grid](
|
||||
a, b, c,
|
||||
M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
)
|
||||
a, b, c, #
|
||||
M, N, K, #
|
||||
a.stride(0), a.stride(1), #
|
||||
b.stride(0), b.stride(1), #
|
||||
c.stride(0), c.stride(1))
|
||||
return c
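A hedged usage sketch for the block-pointer matmul wrapper above, assuming a CUDA device and shapes compatible with the autotuned block sizes; it simply compares against torch.matmul:

import torch

a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
c_triton = matmul(a, b)          # wrapper defined above
c_ref = torch.matmul(a, b)
print(torch.max(torch.abs(c_triton.float() - c_ref.float())))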
|
||||
|
||||
|
||||
|
||||
@@ -40,23 +40,24 @@ if torch.cuda.get_device_capability()[0] < 9:
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7,
|
||||
num_warps=4),
|
||||
# triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4, num_ctas=2),
|
||||
# triton.Config({'BLOCK_SIZE_M': 512, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4, num_ctas=4),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
@triton.jit
|
||||
def matmul_kernel(
|
||||
a_ptr, b_ptr, z_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_zm, stride_zn,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
|
||||
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr,
|
||||
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr
|
||||
):
|
||||
def matmul_kernel(a_ptr, b_ptr, z_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_zm, stride_zn, #
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr, #
|
||||
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, #
|
||||
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr #
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
@@ -70,9 +71,11 @@ def matmul_kernel(
|
||||
block_offset_n = pid_n * BLOCK_SIZE_N
|
||||
|
||||
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
|
||||
offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), order=(A_ORDER_0, A_ORDER_1))
|
||||
offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
|
||||
order=(A_ORDER_0, A_ORDER_1))
|
||||
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
|
||||
offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(B_ORDER_0, B_ORDER_1))
|
||||
offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
|
||||
order=(B_ORDER_0, B_ORDER_1))
|
||||
z = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
offs_m = block_offset_m + tl.arange(0, BLOCK_SIZE_M)
|
||||
@@ -101,15 +104,17 @@ def matmul(a, b, a_order, b_order):
|
||||
z = torch.empty((M, N), device=a.device, dtype=torch.float16)
|
||||
|
||||
def grid(META):
|
||||
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
|
||||
matmul_kernel[grid](a_ptr=a, b_ptr=b, z_ptr=z,
|
||||
M=M, N=N, K=K,
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_zm=z.stride(0), stride_zn=z.stride(1),
|
||||
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1],
|
||||
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1]
|
||||
)
|
||||
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
|
||||
|
||||
matmul_kernel[grid](
|
||||
a_ptr=a, b_ptr=b, z_ptr=z, #
|
||||
M=M, N=N, K=K, #
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1), #
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1), #
|
||||
stride_zm=z.stride(0), stride_zn=z.stride(1), #
|
||||
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
|
||||
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1] #
|
||||
)
|
||||
return z
|
||||
|
||||
|
||||
@@ -160,14 +165,12 @@ def test_matmul():
|
||||
# label name for the lines
|
||||
line_names=["cuBLAS", "Triton"],
|
||||
# line styles
|
||||
styles=[('green', '-'), ('green', '--'),
|
||||
('blue', '-'), ('blue', '--')],
|
||||
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
|
||||
ylabel="TFLOPS", # label name for the y-axis
|
||||
plot_name="matmul-performance",
|
||||
# name for the plot. Used also as a file name for saving the plot.
|
||||
args={},
|
||||
)
|
||||
)
|
||||
))
|
||||
def benchmark(M, N, K, TRANS_A, TRANS_B, provider):
|
||||
if (TRANS_A):
|
||||
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
|
||||
@@ -185,14 +188,15 @@ def benchmark(M, N, K, TRANS_A, TRANS_B, provider):
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == 'cublas':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: torch.matmul(a, b), rep=100, quantiles=quantiles, fast_flush=False)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100, quantiles=quantiles,
|
||||
fast_flush=False)
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: matmul(a, b, a_order, b_order), rep=100, quantiles=quantiles, fast_flush=False)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, a_order, b_order), rep=100,
|
||||
quantiles=quantiles, fast_flush=False)
|
||||
|
||||
def perf(ms):
|
||||
return 2 * M * N * K * 1e-12 / (ms * 1e-3)
|
||||
|
||||
return perf(ms), perf(max_ms), perf(min_ms)
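The perf helper turns a runtime in milliseconds into TFLOP/s: a GEMM performs 2 * M * N * K floating-point operations, 1e-12 scales to tera, and ms * 1e-3 converts to seconds. A quick worked example with made-up numbers:

M = N = K = 2048
ms = 1.0                                      # hypothetical runtime
tflops = 2 * M * N * K * 1e-12 / (ms * 1e-3)
print(round(tflops, 2))  # ~17.18 TFLOP/s for a 2048^3 GEMM in 1 ms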
|
||||
|
||||
|
||||
|
||||
@@ -40,21 +40,21 @@ if torch.cuda.get_device_capability()[0] < 9:
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7,
|
||||
num_warps=4),
|
||||
# triton.Config({'BLOCK_SIZE_M': 512, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4, num_ctas=4),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
@triton.jit
|
||||
def matmul_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
):
|
||||
def matmul_kernel(a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
|
||||
GROUP_SIZE_M: tl.constexpr #
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
@@ -67,20 +67,10 @@ def matmul_kernel(
|
||||
block_offset_m = pid_m * BLOCK_SIZE_M
|
||||
block_offset_n = pid_n * BLOCK_SIZE_N
|
||||
|
||||
a_tile_ptr = tl.make_block_ptr(
|
||||
base=a_ptr, shape=(
|
||||
M, K), strides=(
|
||||
stride_am, stride_ak), offsets=(
|
||||
block_offset_m, 0), block_shape=(
|
||||
BLOCK_SIZE_M, BLOCK_SIZE_K), order=(
|
||||
1, 0))
|
||||
b_tile_ptr = tl.make_block_ptr(
|
||||
base=b_ptr, shape=(
|
||||
K, N), strides=(
|
||||
stride_bk, stride_bn), offsets=(
|
||||
0, block_offset_n), block_shape=(
|
||||
BLOCK_SIZE_K, BLOCK_SIZE_N), order=(
|
||||
0, 1))
|
||||
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
|
||||
offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0))
|
||||
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
|
||||
offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(0, 1))
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
for k in range(0, K, BLOCK_SIZE_K):
|
||||
@@ -91,7 +81,8 @@ def matmul_kernel(
|
||||
b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0])
|
||||
|
||||
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
|
||||
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
|
||||
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N),
|
||||
order=(1, 0))
|
||||
|
||||
tl.store(c_block_ptr, accumulator)
|
||||
|
||||
@@ -101,20 +92,19 @@ def matmul(a, b):
|
||||
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||
M, K = a.shape
|
||||
K, N = b.shape
|
||||
assert (
|
||||
K % 32 == 0
|
||||
), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"
|
||||
assert (K % 32 == 0), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"
|
||||
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
|
||||
|
||||
def grid(META):
|
||||
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
|
||||
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
|
||||
|
||||
matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
|
||||
M=M, N=N, K=K,
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1))
|
||||
matmul_kernel[grid](
|
||||
a_ptr=a, b_ptr=b, c_ptr=c, #
|
||||
M=M, N=N, K=K, #
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1), #
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1), #
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1))
|
||||
return c
|
||||
|
||||
|
||||
@@ -126,12 +116,7 @@ c = torch.nn.functional.normalize(c)
|
||||
golden = torch.nn.functional.normalize(torch.matmul(a, b))
|
||||
|
||||
torch.set_printoptions(profile="full")
|
||||
assert_close(
|
||||
c,
|
||||
golden,
|
||||
rtol=1e-2,
|
||||
atol=1e-3,
|
||||
check_dtype=False)
|
||||
assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
@@ -143,7 +128,7 @@ assert_close(
|
||||
[2048, 1024, 1024],
|
||||
[2048, 2048, 2048],
|
||||
[2048, 4096, 4096],
|
||||
[2048, 8192, 8192]
|
||||
[2048, 8192, 8192],
|
||||
], # different possible values for `x_name`
|
||||
line_arg='provider',
|
||||
# argument name whose value corresponds to a different line in the plot
|
||||
@@ -152,27 +137,26 @@ assert_close(
|
||||
# label name for the lines
|
||||
line_names=["cuBLAS", "Triton"],
|
||||
# line styles
|
||||
styles=[('green', '-'), ('green', '--'),
|
||||
('blue', '-'), ('blue', '--')],
|
||||
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
|
||||
ylabel="TFLOPS", # label name for the y-axis
|
||||
plot_name="matmul-performance",
|
||||
# name for the plot. Used also as a file name for saving the plot.
|
||||
args={},
|
||||
)
|
||||
)
|
||||
))
|
||||
def benchmark(M, N, K, provider):
|
||||
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == 'cublas':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: torch.matmul(a, b), rep=100, quantiles=quantiles, fast_flush=False)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100, quantiles=quantiles,
|
||||
fast_flush=False)
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: matmul(a, b), rep=100, quantiles=quantiles, fast_flush=False)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), rep=100, quantiles=quantiles,
|
||||
fast_flush=False)
|
||||
|
||||
def perf(ms):
|
||||
return 2 * M * N * K * 1e-12 / (ms * 1e-3)
|
||||
|
||||
return perf(ms), perf(max_ms), perf(min_ms)
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,13 @@
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
"""
|
||||
Group GEMM
|
||||
============================
|
||||
This group gemm kernel launches a fixed number of CTA to compute a group
|
||||
of gemms. The scheduling is static and we do it on device.
|
||||
"""
|
||||
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
@@ -28,6 +38,7 @@ import triton.language as tl
|
||||
# of gemms. The scheduling is static and we do it on device
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
<<<<<<< HEAD
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
@@ -111,6 +122,32 @@ import triton.language as tl
|
||||
num_stages = 0,
|
||||
num_warps = 2,
|
||||
),
|
||||
=======
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'NUM_SM': 84,
|
||||
}),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'NUM_SM': 128,
|
||||
}),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'NUM_SM': 84,
|
||||
}),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'NUM_SM': 128,
|
||||
}),
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
],
|
||||
key=['SUM_M', 'SUM_N', 'SUM_K'],
|
||||
)
|
||||
@@ -149,9 +186,7 @@ def grouped_matmul_kernel(
|
||||
num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
|
||||
num_tiles = num_m_tiles * num_n_tiles
|
||||
# iterate through the tiles in the current gemm problem
|
||||
while (
|
||||
tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles
|
||||
):
|
||||
while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles):
|
||||
# pick up a tile from the current gemm problem
|
||||
k = gk
|
||||
lda = tl.load(g_lds + g * 3)
|
||||
@@ -171,9 +206,7 @@ def grouped_matmul_kernel(
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]
|
||||
b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]
|
||||
accumulator = tl.zeros(
|
||||
(BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
|
||||
)
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
|
||||
# hint to Triton compiler to do proper loop pipelining
|
||||
tl.multiple_of(a_ptrs, [16, 16])
|
||||
@@ -224,7 +257,7 @@ def group_gemm_fn(group_A, group_B):
|
||||
group_C.append(C)
|
||||
A_addrs.append(A.data_ptr())
|
||||
B_addrs.append(B.data_ptr())
|
||||
C_addrs .append(C.data_ptr())
|
||||
C_addrs.append(C.data_ptr())
|
||||
g_sizes += [M, N, K]
|
||||
SUM_M += M
|
||||
SUM_N += N
|
||||
@@ -235,14 +268,10 @@ def group_gemm_fn(group_A, group_B):
|
||||
d_a_ptrs = torch.tensor(A_addrs, device=device)
|
||||
d_b_ptrs = torch.tensor(B_addrs, device=device)
|
||||
d_c_ptrs = torch.tensor(C_addrs, device=device)
|
||||
d_g_sizes = torch.tensor(
|
||||
g_sizes, dtype=torch.int32, device=device
|
||||
)
|
||||
d_g_lds = torch.tensor(
|
||||
g_lds, dtype=torch.int32, device=device
|
||||
)
|
||||
d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=device)
|
||||
d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device)
|
||||
# we use a fixed number of CTA, and it's auto-tunable
|
||||
grid = lambda META: (META['NUM_SM'],)
|
||||
grid = lambda META: (META['NUM_SM'], )
|
||||
grouped_matmul_kernel[grid](
|
||||
d_a_ptrs,
|
||||
d_b_ptrs,
|
||||
@@ -283,8 +312,13 @@ for i in range(group_size):
|
||||
|
||||
|
||||
# only launch the kernel, no tensor preparation here to remove all overhead
|
||||
<<<<<<< HEAD
|
||||
def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size, sum_m, sum_n, sum_k):
|
||||
grid = lambda META: (META['NUM_SM'],)
|
||||
=======
|
||||
def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size):
|
||||
grid = lambda META: (META['NUM_SM'], )
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
grouped_matmul_kernel[grid](
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
@@ -307,7 +341,7 @@ def torch_perf_fn(group_A, group_B):
|
||||
triton.testing.Benchmark(
|
||||
# argument names to use as an x-axis for the plot
|
||||
x_names=['N'],
|
||||
x_vals=[2 ** i for i in range(7, 11)], # different possible values for `x_name`
|
||||
x_vals=[2**i for i in range(7, 11)], # different possible values for `x_name`
|
||||
line_arg='provider',
|
||||
# argument name whose value corresponds to a different line in the plot
|
||||
# possible values for `line_arg``
|
||||
@@ -320,8 +354,7 @@ def torch_perf_fn(group_A, group_B):
|
||||
plot_name="group-gemm-performance",
|
||||
# name for the plot. Used also as a file name for saving the plot.
|
||||
args={},
|
||||
)
|
||||
)
|
||||
))
|
||||
def benchmark(N, provider):
|
||||
group_size = 4
|
||||
group_A = []
|
||||
@@ -341,7 +374,7 @@ def benchmark(N, provider):
|
||||
group_C.append(C)
|
||||
A_addrs.append(A.data_ptr())
|
||||
B_addrs.append(B.data_ptr())
|
||||
C_addrs .append(C.data_ptr())
|
||||
C_addrs.append(C.data_ptr())
|
||||
g_sizes += [N, N, N]
|
||||
g_lds += [N, N, N]
|
||||
|
||||
@@ -355,7 +388,12 @@ def benchmark(N, provider):
|
||||
if provider == 'cublas':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles)
|
||||
if provider == 'triton':
|
||||
<<<<<<< HEAD
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size, group_size*N, group_size*N, group_size*N), quantiles=quantiles)
|
||||
=======
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size), quantiles=quantiles)
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
return ms, max_ms, min_ms
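The grouped GEMM benchmarked here relies on static on-device scheduling: a fixed grid of NUM_SM programs is launched and each program strides through the concatenated list of output tiles of all problems in steps of NUM_SM. A pure-Python sketch of that tile assignment (illustrative only; the tile counts are invented):

NUM_SM = 4
tiles_per_problem = [6, 2, 5]            # hypothetical tile counts per GEMM
total_tiles = sum(tiles_per_problem)
for pid in range(NUM_SM):                # one iteration per persistent program
    owned = list(range(pid, total_tiles, NUM_SM))
    print(f"program {pid} -> tiles {owned}")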