Merge commit 'ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33' into ifu-rebase-again

Conflicts:
	.gitignore
	.gitmodules
	README.md
	bin/triton-translate.cpp
	include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
	include/triton/Target/AMDGCN/AMDGCNTranslation.h
	include/triton/Target/HSACO/HSACOTranslation.h
	lib/Analysis/Allocation.cpp
	lib/Analysis/Utility.cpp
	lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
	lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/Utility.cpp
	lib/Conversion/TritonGPUToLLVM/Utility.h
	lib/Dialect/TritonGPU/IR/Dialect.cpp
	lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
	lib/Target/HSACO/CMakeLists.txt
	lib/Target/HSACO/HSACOTranslation.cpp
	lib/Target/LLVMIR/LLVMIRTranslation.cpp
	python/src/triton.cc
	python/test/unit/language/test_core.py
	python/test/unit/operators/test_flash_attention.py
	python/triton/compiler/compiler.py
	python/triton/compiler/make_launcher.py
	python/triton/language/semantic.py
	python/triton/runtime/jit.py
	python/tutorials/06-fused-attention.py
	python/tutorials/11-grouped-gemm.py
	test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
Jason Furmanek
2023-11-06 23:10:10 +00:00
161 changed files with 6530 additions and 3905 deletions

View File

@@ -101,20 +101,34 @@ def get_backend(device_type: str):
return _backends[device_type] if device_type in _backends else None
def _path_to_binary(binary: str):
    """Locate a CUDA tool *binary* (e.g. ``ptxas``) and return ``(path, version)``.

    Search order: the ``TRITON_PTXAS_PATH`` environment variable, then the
    ``third_party/cuda/bin`` directory shipped next to this package.  The
    version string is parsed from the tool's ``--version`` output.

    Raises:
        RuntimeError: if no working binary is found in any search location.
    """
    base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
    paths = [
        os.environ.get("TRITON_PTXAS_PATH", ""),
        os.path.join(base_dir, "third_party", "cuda", "bin", binary)
    ]
    for p in paths:
        # An entry may carry trailing arguments; only the first token is the
        # executable path.  (Renamed from `bin` to avoid shadowing the builtin.)
        bin_path = p.split(" ")[0]
        if os.path.exists(bin_path) and os.path.isfile(bin_path):
            result = subprocess.check_output([bin_path, "--version"], stderr=subprocess.STDOUT)
            if result is not None:
                version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
                if version is not None:
                    return p, version.group(1)
    raise RuntimeError(f"Cannot find {binary}")


@functools.lru_cache()
def path_to_ptxas():
    return _path_to_binary("ptxas")


@functools.lru_cache()
def path_to_cuobjdump():
    return _path_to_binary("cuobjdump")


@functools.lru_cache()
def path_to_nvdisasm():
    return _path_to_binary("nvdisasm")

View File

@@ -199,7 +199,7 @@ class ContainsReturnChecker(ast.NodeVisitor):
class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, attributes, constants, function_name, arch,
def __init__(self, context, prototype, gscope, attributes, constants, function_name, target,
module=None, is_kernel=False, function_types: Optional[Dict] = None,
debug=False, noinline=False, file_name: Optional[str] = None, begin_line=0):
self.context = context
@@ -208,7 +208,7 @@ class CodeGenerator(ast.NodeVisitor):
# node.lineno starts from 1, so we need to subtract 1
self.begin_line = begin_line - 1
self.builder.set_loc(file_name, begin_line, 0)
self.builder.arch = arch
self.builder.target = target
self.module = self.builder.create_module() if module is None else module
self.function_ret_types = {} if function_types is None else function_types
self.prototype = prototype
@@ -912,7 +912,7 @@ class CodeGenerator(ast.NodeVisitor):
file_name, begin_line = _get_fn_file_line(fn)
generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module,
function_name=fn_name, function_types=self.function_ret_types, debug=debug, noinline=fn.noinline,
file_name=file_name, begin_line=begin_line, arch=self.builder.arch)
file_name=file_name, begin_line=begin_line, target=self.builder.target)
generator.visit(fn.parse())
callee_ret_type = generator.last_ret_type
self.function_ret_types[fn_name] = callee_ret_type
@@ -1108,7 +1108,7 @@ def kernel_suffix(signature, specialization):
return suffix
def ast_to_ttir(fn, signature, specialization, constants, debug, arch):
def ast_to_ttir(fn, signature, specialization, constants, debug, target):
# canonicalize signature
if isinstance(signature, str):
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
@@ -1137,7 +1137,7 @@ def ast_to_ttir(fn, signature, specialization, constants, debug, arch):
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants,
function_name=function_name, attributes=new_attrs,
is_kernel=True, debug=debug, file_name=file_name, begin_line=begin_line,
arch=arch)
target=target)
try:
generator.visit(fn.parse())
except CompilationError as e:

View File

@@ -5,10 +5,18 @@ import hashlib
import json
import os
import re
<<<<<<< HEAD
import tempfile
from collections import namedtuple
from pathlib import Path
from typing import Any
=======
from collections import namedtuple
from pathlib import Path
from typing import Any
from dataclasses import dataclass
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs,
compile_ptx_to_cubin, get_env_vars, get_num_warps,
@@ -20,11 +28,11 @@ from ..common.build import is_hip
# from ..runtime import driver, jit, JITFunction
# TODO: runtime.errors
from ..runtime.autotuner import OutOfResources
from ..runtime.cache import get_cache_manager
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
from ..runtime.driver import driver
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device,
get_device_capability, version_key)
from ..tools.disasm import extract
from ..tools.disasm import get_sass
from .code_generator import ast_to_ttir
from .make_launcher import make_stub
from .utils import (InfoFromBackendForTensorMap, TensorMapManager,
@@ -32,6 +40,24 @@ from .utils import (InfoFromBackendForTensorMap, TensorMapManager,
CUDA_DEFAULT_WARP_SIZE = 32
# Descriptor of a CUDA compilation target; threaded through the compilation
# pipeline (and into cache-key strings) instead of a bare capability integer.
@dataclass
class CudaTargetDescriptor:
    # Compute capability encoded as major * 10 + minor (e.g. 80 for sm_80),
    # as produced by get_cuda_capability().
    capability: int
    # Number of warps per CTA requested for this compilation.
    num_warps: int


def _is_cuda(target):
    # A target is CUDA iff it is described by a CudaTargetDescriptor; other
    # backends pass their own descriptor objects (e.g. dict-like for HIP).
    return isinstance(target, CudaTargetDescriptor)
class LazyDict(dict):
    """A dict whose values may be zero-argument callables.

    A callable value is invoked on every access and its result returned, so
    expensive artifacts (e.g. disassembly) are produced only when requested.
    Non-callable values behave exactly as in a plain dict.
    """

    def __getitem__(self, key):
        stored = super().__getitem__(key)
        return stored() if callable(stored) else stored
def inline_triton_ir(mod):
pm = ir.pass_manager(mod.context)
pm.enable_debug()
@@ -40,11 +66,12 @@ def inline_triton_ir(mod):
return mod
def ttir_compute_capability_rewrite(mod, arch):
def ttir_compute_capability_rewrite(mod, target):
# For hardware without support, we must rewrite all load/store
# with block (tensor) pointers into tensors of pointers
pm = ir.pass_manager(mod.context)
pm.enable_debug()
<<<<<<< HEAD
if _is_cuda(arch):
pm.add_rewrite_tensor_pointer_pass(arch, False)
elif is_hip():
@@ -52,13 +79,17 @@ def ttir_compute_capability_rewrite(mod, arch):
pm.add_rewrite_tensor_pointer_pass(capability, True)
else:
assert(False, "unsupported target")
=======
if _is_cuda(target):
pm.add_rewrite_tensor_pointer_pass(target.capability)
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
pm.run(mod)
return mod
def optimize_ttir(mod, arch):
def optimize_ttir(mod, target):
mod = inline_triton_ir(mod)
mod = ttir_compute_capability_rewrite(mod, arch)
mod = ttir_compute_capability_rewrite(mod, target)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_inliner_pass()
@@ -72,6 +103,7 @@ def optimize_ttir(mod, arch):
return mod
<<<<<<< HEAD
def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, arch):
pm = ir.pass_manager(mod.context)
pm.enable_debug()
@@ -79,21 +111,36 @@ def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, arch):
pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize, num_ctas, 0)
else:
pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize, num_ctas, arch)
=======
def ttir_to_ttgir(mod, num_warps, num_ctas, target):
pm = ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_convert_triton_to_tritongpu_pass(num_warps, 32, num_ctas, target.capability)
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
pm.run(mod)
return mod
<<<<<<< HEAD
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch,
cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_inst_type):
=======
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target,
cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue):
is_cuda = _is_cuda(target)
if is_cuda:
capability = target.capability
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
pm = ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_tritongpu_coalesce_pass()
# TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
pm.add_plan_cta_pass(cluster_info)
if _is_cuda(arch):
pm.add_tritongpu_rewrite_tensor_pointer_pass(arch)
if is_cuda:
pm.add_tritongpu_rewrite_tensor_pointer_pass(capability)
pm.add_plan_cta_pass(cluster_info)
pm.add_tritongpu_remove_layout_conversions_pass()
<<<<<<< HEAD
if _is_cuda(arch):
pm.add_tritongpu_accelerate_matmul_pass(arch)
# TODO change interface of accelerate_matmul_pass
@@ -101,6 +148,10 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch,
matrix_core_version = gpu_matrix_core_version()
matrix_inst_size = matrix_inst_type
pm.add_tritonamdgpu_accelerate_matmul_pass(matrix_core_version, matrix_inst_size)
=======
if is_cuda:
pm.add_tritongpu_accelerate_matmul_pass(capability)
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
pm.add_tritongpu_remove_layout_conversions_pass()
if optimize_epilogue:
pm.add_tritongpu_optimize_epilogue_pass()
@@ -114,20 +165,25 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch,
# it's the responsibility of the compiler to figure out the exact
# `num_warps` to use.
# TODO: support the case where `num_warps` from user is not 4.
<<<<<<< HEAD
if _is_cuda(arch) and arch // 10 >= 9 and enable_warp_specialization and num_warps == 4:
pm.add_tritongpu_ws_feasibility_checking_pass(arch)
=======
if capability // 10 >= 9 and enable_warp_specialization and num_warps == 4:
pm.add_tritongpu_ws_feasibility_checking_pass(capability)
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
pm.run(mod)
ws_enabled = ir.is_ws_supported(mod)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
if ws_enabled:
pm.add_tritongpu_wsdecomposing_pass(arch)
pm.add_tritongpu_wspipeline_pass(
num_stages, num_warps, arch)
pm.add_tritongpu_wsmutex_pass(arch)
pm.add_tritongpu_wsmaterialization_pass(arch)
pm.add_tritongpu_wsdecomposing_pass(capability)
pm.add_tritongpu_wspipeline_pass(num_stages, num_warps, capability)
pm.add_tritongpu_wsmutex_pass(capability)
pm.add_tritongpu_wsmaterialization_pass(capability)
pm.add_cse_pass()
else:
<<<<<<< HEAD
if is_hip():
pm.add_tritongpu_pipeline_pass(
num_stages, num_warps, num_ctas, 0)
@@ -139,6 +195,11 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch,
else:
pm.add_tritongpu_materialize_load_store_pass(num_warps, arch)
if _is_cuda(arch) and arch // 10 <= 8:
=======
pm.add_tritongpu_pipeline_pass(num_stages, num_warps, num_ctas, capability)
pm.add_tritongpu_materialize_load_store_pass(num_warps, capability)
if capability // 10 <= 8:
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
pm.add_tritongpu_prefetch_pass()
pm.add_tritongpu_optimize_dot_operands_pass()
pm.add_tritongpu_remove_layout_conversions_pass()
@@ -148,7 +209,11 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch,
pm.add_tritongpu_reorder_instructions_pass()
pm.add_cse_pass()
pm.add_symbol_dce_pass()
<<<<<<< HEAD
if _is_cuda(arch) and arch // 10 >= 9:
=======
if capability // 10 >= 9:
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
pm.add_tritongpu_fence_insertion_pass()
pm.add_tritongpu_ws_fixup_missing_attrs_pass()
pm.run(mod)
@@ -162,12 +227,21 @@ def _add_external_libs(mod, libs):
add_external_libs(mod, list(libs.keys()), list(libs.values()))
<<<<<<< HEAD
def ttgir_to_llir(mod, extern_libs, arch, tma_infos, waves_per_eu=0):
if extern_libs:
_add_external_libs(mod, extern_libs)
# TODO: separate tritongpu_to_llvmir for different backends
if _is_cuda(arch):
return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, runtime.TARGET.NVVM, waves_per_eu)
=======
def ttgir_to_llir(mod, extern_libs, target, tma_infos):
if extern_libs:
_add_external_libs(mod, extern_libs)
# TODO: separate tritongpu_to_llvmir for different backends
if _is_cuda(target):
return translate_triton_gpu_to_llvmir(mod, target.capability, tma_infos, runtime.TARGET.NVVM)
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
else:
return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL, waves_per_eu)
@@ -190,7 +264,7 @@ def ptx_get_version(cuda_version) -> int:
raise RuntimeError("Triton only support CUDA 10.0 or higher")
def llir_to_ptx(mod: Any, arch: int, ptx_version: int = None) -> str:
def llir_to_ptx(mod: Any, target: CudaTargetDescriptor, ptx_version: int = None) -> str:
'''
Translate TritonGPU module to PTX code.
:param mod: a TritonGPU dialect module
@@ -199,10 +273,10 @@ def llir_to_ptx(mod: Any, arch: int, ptx_version: int = None) -> str:
if ptx_version is None:
_, cuda_version = path_to_ptxas()
ptx_version = ptx_get_version(cuda_version)
return translate_llvmir_to_ptx(mod, arch, ptx_version)
return translate_llvmir_to_ptx(mod, target.capability, ptx_version)
def ptx_to_cubin(ptx: str, arch: int):
def ptx_to_cubin(ptx: str, target: CudaTargetDescriptor):
'''
Compile TritonGPU module to cubin.
:param ptx: ptx code
@@ -210,7 +284,11 @@ def ptx_to_cubin(ptx: str, arch: int):
:return: str
'''
ptxas, _ = path_to_ptxas()
<<<<<<< HEAD
return compile_ptx_to_cubin(ptx, ptxas, arch)
=======
return compile_ptx_to_cubin(ptx, ptxas, target.capability)
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
# ------------------------------------------------------------------------------
@@ -230,13 +308,15 @@ def get_kernel_name(src: str, pattern: str) -> str:
def convert_type_repr(x):
    """Convert an MLIR type string to Triton's short pointer representation.

    ``!tt.ptr<T>`` (optionally with an address-space suffix, e.g.
    ``!tt.ptr<f32, 1>``) becomes ``*`` followed by the converted repr of T;
    any other string is returned unchanged.
    Currently we only capture the pointer type and assume the pointer is on
    global memory.
    TODO: Capture and support shared memory space.
    """
    # [^,>]+ stops at either the optional address-space comma or the closing
    # '>', so both '!tt.ptr<f32>' and '!tt.ptr<f32, 1>' yield 'f32'.
    # (A bare [^,]+ would greedily swallow the trailing '>' when no
    # address-space suffix is present.)
    match = re.search(r'!tt\.ptr<([^,>]+)', x)
    if match is not None:
        return '*' + convert_type_repr(match.group(1))
    return x
def make_hash(fn, arch, env_vars, **kwargs):
def make_hash(fn, target, env_vars, **kwargs):
if isinstance(fn, JITFunction):
configs = kwargs["configs"]
signature = kwargs["signature"]
@@ -253,9 +333,16 @@ def make_hash(fn, arch, env_vars, **kwargs):
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
configs_key = [get_conf_key(conf) for conf in configs]
env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())]
<<<<<<< HEAD
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{arch}-{env_vars_list}"
=======
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
return hashlib.md5(key.encode("utf-8")).hexdigest()
assert isinstance(fn, str)
ignore_version = kwargs.get('ignore_version', False)
if (ignore_version):
return hashlib.md5((Path(fn).read_text()).encode("utf-8")).hexdigest()
return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest()
@@ -266,7 +353,8 @@ def make_hash(fn, arch, env_vars, **kwargs):
# (letters, digits, or underscores), and capture it as group 1 (the function name)
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
mlir_prototype_pattern = r'^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3
mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$"
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
prototype_pattern = {
"ttir": mlir_prototype_pattern,
@@ -274,7 +362,11 @@ prototype_pattern = {
"ptx": ptx_prototype_pattern,
}
mlir_arg_type_pattern = r'%\w+: ([^,^\)\s]+)(?: \{\S+ = \S+ : \S+\})?,?'
# - ((?:[^,\s<]+|<[^>]+>)+): Capturing group that matches one or more of either:
# [^,\s<]+: One or more characters that are not a comma, whitespace, or the < symbol.
# |: OR
# <[^>]+>: A string that starts with < and ends with >, containing any characters except > in between.
mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<]+|<[^>]+>)+),?'
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
arg_type_pattern = {
"ttir": mlir_arg_type_pattern,
@@ -282,7 +374,11 @@ arg_type_pattern = {
"ptx": ptx_arg_type_pattern,
}
if is_hip():
<<<<<<< HEAD
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
=======
ttgir_num_warps_pattern = r'"triton_gpu_rocm.num-warps"\s?=\s?(\d+)\s?:'
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
else:
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
@@ -311,6 +407,7 @@ def parse_mlir_module(path, context):
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], defaults=[set(), set(), set(), set()])
<<<<<<< HEAD
# TODO: architecture descriptor class
def _is_cuda(arch):
return isinstance(arch, int)
@@ -337,6 +434,15 @@ def get_architecture_descriptor(capability):
capability = get_device_capability(device)
capability = capability[0] * 10 + capability[1]
return capability
=======
def get_cuda_capability(capability):
    """Return the CUDA compute capability encoded as ``major * 10 + minor``.

    A non-None *capability* is returned unchanged; when None, the capability
    is queried from the currently active device.
    """
    if capability is not None:
        return capability
    dev = get_current_device()
    cc = get_device_capability(dev)
    return cc[0] * 10 + cc[1]
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
@functools.lru_cache
def get_arch_default_num_warps(device_type):
@@ -347,15 +453,19 @@ def get_arch_default_num_warps(device_type):
assert _device_backend
arch = _device_backend.get_architecture_descriptor()
num_warps = arch["num_warps"]
return num_warps
@functools.lru_cache
def get_arch_default_num_stages(device_type, capability=None):
<<<<<<< HEAD
if device_type in ["cuda"]:
arch = get_architecture_descriptor(capability)
is_cuda = device_type == "cuda" and _is_cuda(arch)
num_stages = 3 if is_cuda and arch >= 75 else 2
=======
if device_type == "cuda":
num_stages = 3 if get_cuda_capability(capability) >= 75 else 2
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
else:
_device_backend = get_backend(device_type)
assert _device_backend
@@ -365,11 +475,16 @@ def get_arch_default_num_stages(device_type, capability=None):
return num_stages
def add_cuda_stages(target, extern_libs, stages):
    """Append the CUDA-only compilation stages (``ptx``, ``cubin``) to *stages*.

    Each stage maps its name to a ``(parse_from_path, compile_from_src)`` pair,
    matching the layout built in ``compile``.  *target* is the
    CudaTargetDescriptor forwarded to llir_to_ptx / ptx_to_cubin.
    (Resolves the merge conflict in favor of the target-descriptor API used by
    the surrounding pipeline; *extern_libs* is kept for signature
    compatibility with the backend add_stages hooks.)
    """
    stages["ptx"] = (lambda path: Path(path).read_text(),
                     lambda src: llir_to_ptx(src, target))
    stages["cubin"] = (lambda path: Path(path).read_bytes(),
                       lambda src: ptx_to_cubin(src, target))
def compile(fn, **kwargs):
@@ -379,6 +494,7 @@ def compile(fn, **kwargs):
if is_hip():
device_type = "hip"
<<<<<<< HEAD
capability = None
if device_type == "cuda":
@@ -393,6 +509,12 @@ def compile(fn, **kwargs):
if is_hip():
is_cuda = False
warp_size = CUDA_DEFAULT_WARP_SIZE if _is_cuda(arch) else arch["warp_size"]
=======
is_cuda = device_type == "cuda"
if is_hip():
is_cuda = False
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
context = ir.context()
constants = kwargs.get("constants", dict())
num_warps = kwargs.get("num_warps", get_arch_default_num_warps(device_type))
@@ -420,11 +542,23 @@ def compile(fn, **kwargs):
cluster_info.clusterDimY = kwargs["clusterDims"][1]
cluster_info.clusterDimZ = kwargs["clusterDims"][2]
tma_infos = TMAInfos()
<<<<<<< HEAD
=======
# build architecture descriptor
if device_type == "cuda":
_device_backend = get_backend(device_type)
target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps)
else:
_device_backend = get_backend(device_type)
assert _device_backend
target = _device_backend.get_architecture_descriptor(**kwargs)
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
# build compilation stages
stages = dict()
stages["ast"] = (lambda path: fn, None)
stages["ttir"] = (lambda path: parse_mlir_module(path, context),
<<<<<<< HEAD
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
if is_cuda:
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
@@ -456,12 +590,23 @@ def compile(fn, **kwargs):
stages["llir"] = (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos))
_device_backend.add_stages(arch, extern_libs, stages)
=======
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
if is_cuda:
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
stages["llir"] = (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, target, tma_infos))
add_cuda_stages(target, extern_libs, stages)
elif device_type == "hip":
_device_backend.add_stages(target, extern_libs, stages, num_warps=num_warps, num_stages=num_stages)
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
else:
# pass the user's configuration to the backend device.
arch["num_warps"] = num_warps
arch["num_stages"] = num_stages
arch["num_ctas"] = num_ctas
_device_backend.add_stages(arch, extern_libs, stages)
target["num_warps"] = num_warps
target["num_stages"] = num_stages
target["num_ctas"] = num_ctas
_device_backend.add_stages(target, extern_libs, stages)
# find out the signature of the function
if isinstance(fn, JITFunction):
@@ -482,6 +627,7 @@ def compile(fn, **kwargs):
src = Path(fn).read_text()
import re
match = re.search(prototype_pattern[ir_name], src, re.MULTILINE)
# TODO: support function attributes at group 3 (e.g., device function)
name, signature = match.group(1), match.group(2)
types = re.findall(arg_type_pattern[ir_name], signature)
if ir_name == 'ttgir':
@@ -494,7 +640,12 @@ def compile(fn, **kwargs):
first_stage = list(stages.keys()).index(ir_name)
# create cache manager
fn_cache_manager = get_cache_manager(make_hash(fn, arch, get_env_vars(), **kwargs))
fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), **kwargs))
# managers used to dump and override IR for debugging
enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
fn_override_manager = get_override_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True))
fn_dump_manager = get_dump_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True))
# determine name and extension type of provided function
if isinstance(fn, JITFunction):
name, ext = fn.__name__, "ast"
@@ -529,7 +680,7 @@ def compile(fn, **kwargs):
"enable_persistent": enable_persistent,
"constants": _get_jsonable_constants(constants),
"debug": debug,
"arch": arch, }
"target": target, }
metadata.update(get_env_vars())
if ext == "ptx":
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
@@ -539,7 +690,7 @@ def compile(fn, **kwargs):
metadata["device_type"] = device_type
first_stage = list(stages.keys()).index(ext)
asm = dict()
asm = LazyDict()
module = fn
# run compilation pipeline and populate metadata
for ir_name, (parse, compile_kernel) in list(stages.items())[first_stage:]:
@@ -557,7 +708,11 @@ def compile(fn, **kwargs):
metadata_group[extra_file_name] = fn_cache_manager.put(next_module[1], extra_file_name)
else:
metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
fn_cache_manager.put(next_module, ir_filename)
fn_dump_manager.put(next_module, ir_filename)
if (enable_override and fn_override_manager.has_file(ir_filename)):
print(f"\nOverriding kernel with file {ir_filename}")
full_name = fn_override_manager.get_file(ir_filename)
next_module = parse(full_name)
else:
if ir_name == "amdgcn":
extra_file_name = f"{name}.hsaco_path"
@@ -569,6 +724,7 @@ def compile(fn, **kwargs):
if ir_name == "cubin":
asm[ir_name] = next_module
asm["sass"] = lambda: get_sass(next_module)
elif ir_name == "amdgcn":
asm[ir_name] = str(next_module[0])
else:
@@ -579,11 +735,19 @@ def compile(fn, **kwargs):
else:
metadata["shared"] = get_shared_memory_size(module)
if ir_name == "ttgir":
<<<<<<< HEAD
metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module)
if metadata["enable_warp_specialization"]:
if is_hip():
metadata["num_warps"] = _device_backend.get_num_warps(next_module)
else:
=======
if is_hip():
metadata["num_warps"] = _device_backend.get_num_warps(next_module)
else:
metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module)
if metadata["enable_warp_specialization"]:
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
metadata["num_warps"] = get_num_warps(next_module)
if ir_name == "ptx":
metadata["name"] = get_kernel_name(next_module, pattern='// .globl')
@@ -723,16 +887,3 @@ class CompiledKernel:
self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function,
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand)
return runner
def get_sass(self, fun=None):
    """Return the SASS disassembly of this kernel's cubin.

    The result is memoized in ``self.asm['sass']`` (and ``self.sass``) so the
    external disassembler runs at most once per kernel.
    """
    # Serve the memoized disassembly when available.
    if 'sass' in self.asm:
        return self.asm['sass']
    # The disassembler works on files, so spill the cubin to a temp path.
    fd, cubin_path = tempfile.mkstemp()
    try:
        with open(fd, 'wb') as tmp:
            tmp.write(self.asm['cubin'])
        self.sass = extract(cubin_path, fun)
    finally:
        os.remove(cubin_path)
    self.asm['sass'] = self.sass
    return self.sass

View File

@@ -63,8 +63,9 @@ def ty_to_cpp(ty):
def generate_launcher(constants, signature, ids):
start_desc = len(signature)
signature = generate_cu_signature(constants, signature, ids)
# Record the end of regular arguments;
# subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
signature, desc_start_idx = generate_cu_signature(constants, signature, ids)
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
def _extracted_type(ty):
@@ -99,7 +100,11 @@ def generate_launcher(constants, signature, ids):
# generate glue code
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
<<<<<<< HEAD
params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)]
=======
params = [i for i in signature.keys() if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)]
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
src = f"""
#include \"cuda.h\"
#include <stdbool.h>
@@ -116,7 +121,10 @@ static inline void gpuAssert(CUresult code, const char *file, int line)
char err[1024] = {{0}};
strcat(err, prefix);
strcat(err, str);
PyGILState_STATE gil_state;
gil_state = PyGILState_Ensure();
PyErr_SetString(PyExc_RuntimeError, err);
PyGILState_Release(gil_state);
}}
}}
@@ -251,6 +259,9 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
Py_BEGIN_ALLOW_THREADS;
_launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''});
Py_END_ALLOW_THREADS;
if (PyErr_Occurred()) {{
return NULL;
}}
if (launch_exit_hook != Py_None && !PyObject_CallObject(launch_exit_hook, args)) {{
return NULL;

View File

@@ -26,12 +26,11 @@ from ..runtime import driver
def generate_cu_signature(constants, signature, ids):
    """Extend *signature* with the CUtensorMap* descriptor arguments.

    CUtensorMap*s are always the last arguments.  Returns
    ``(signature, desc_start_idx)`` where ``desc_start_idx`` is the index of
    the first descriptor (== the number of regular arguments), so launcher
    code can separate regular parameters from tensor-map descriptors.
    The caller's dict is never mutated: a copy is extended when descriptors
    are present.
    """
    num_regular_signatures = max(signature.keys()) + 1 if len(signature) > 0 else 0
    if ids["ids_of_tensormaps"] is not None:
        # Copy before extending so the caller's signature stays untouched.
        signature = signature.copy()
        for i, _ in enumerate(ids["ids_of_tensormaps"]):
            signature[num_regular_signatures + i] = '*CUtensorMap'
    return signature, num_regular_signatures
def dummy_tensormaps_info(n=2):

View File

@@ -1,9 +0,0 @@
from typing import Tuple
import dataclasses
@dataclasses.dataclass
class ExecutionContext:
    """Per-program-id launch state for one simulated kernel instance."""
    # Current program id along each grid axis (one entry per dimension).
    program_id: Tuple[int, ...]
    # Total grid extent along each axis.
    program_size: Tuple[int, ...]

View File

@@ -1,171 +0,0 @@
import itertools
import random
from typing import Tuple
from .. import language as tl
# import .language.core as lcore
from ..language import core as lcore
from . import torch_wrapper
from .core import ExecutionContext
from .memory_map import MemoryMap
from .tl_lang import (TritonLangProxy, WrappedTensor, _primitive_to_tensor,
debugger_constexpr)
torch = torch_wrapper.torch
tl_method_backup = {}
def get_proxy_method(proxy, name):
    """Build a forwarding wrapper around ``proxy.<name>``.

    The attribute is resolved once up front; the returned function simply
    delegates every call (positional and keyword args) to it.
    """
    bound = getattr(proxy, name)

    def _forward(*args, **kwargs):
        return bound(*args, **kwargs)

    return _forward
def attach_triton(module, proxy):
    """Monkey-patch *module* so its TritonLangProxy-named public attributes
    are served by *proxy*.

    Replaced attributes are saved in ``tl_method_backup`` so detach_triton
    can restore them afterwards.
    """
    for name in dir(TritonLangProxy):
        # Only public proxy names that the module actually exposes.
        if name.startswith("_") or not hasattr(module, name):
            continue
        original = getattr(module, name)
        tl_method_backup[name] = original
        if callable(original):
            replacement = get_proxy_method(proxy, name)
        else:
            replacement = getattr(proxy, name)
        setattr(module, name, replacement)
def detach_triton(module):
    """Undo attach_triton: restore every backed-up attribute on *module*."""
    for attr_name, original in tl_method_backup.items():
        setattr(module, attr_name, original)
def program_ids_from_grid(grid: Tuple[int, ...]) -> Tuple[int, ...]:
    """Yield every program-id tuple of *grid* in a random order.

    Dimensions are reversed before taking the cartesian product (matching the
    original launch-order convention); the shuffle randomizes execution order
    across program ids.
    """
    axes = [range(extent) for extent in reversed(grid)]
    combos = list(itertools.product(*axes))
    random.shuffle(combos)
    yield from combos
class DebuggerFunction:
    """Eagerly executes a kernel's Python function once per program id,
    with the triton language module proxied through TritonLangProxy."""

    def __init__(self, func, grid=(1,)):
        # func: the kernel's underlying Python function.
        # grid: a tuple, or a callable mapping launch kwargs -> tuple
        # (same convention as real triton launch grids).
        self.func = func
        self.grid = grid

    def _is_constexpr(self, name):
        # True iff the parameter is annotated with tl.core.constexpr.
        return name in self.func.__annotations__ and self.func.__annotations__[name] is lcore.constexpr

    def _get_constexpr(self):
        # Names of all parameters annotated as constexpr.
        result = []
        for name, annotation in self.func.__annotations__.items():
            if annotation is lcore.constexpr:
                result.append(name)
        return result

    def _assert_constexpr(self, **kwargs):
        # constexpr parameters must be supplied as keyword arguments.
        constexp = self._get_constexpr()
        missing = [i for i in constexp if i not in kwargs.keys()]
        assert len(missing) == 0, f"You must specify constexpr {missing}"

    def _get_grid(self, **kwargs):
        # Grids may be callables of the launch kwargs, as in triton proper.
        if callable(self.grid):
            return self.grid(kwargs)
        else:
            return self.grid

    def __call__(self, *args, **kwargs):
        self._assert_constexpr(**kwargs)
        # One memory map shared by every program id of this launch: tensors
        # are registered once and addressed via synthetic pointers.
        memory = MemoryMap()

        def convert_arg(v):
            name, arg = v
            if torch.is_tensor(arg):
                # Register the tensor and pass its synthetic base pointer.
                ptr = memory.add_tensor(arg)
                return WrappedTensor(torch.tensor([ptr], dtype=torch.int64, device="cuda"))
            if self._is_constexpr(name):
                return debugger_constexpr(arg)
            return WrappedTensor(_primitive_to_tensor(arg))

        # Pair positional args with parameter names via co_varnames;
        # num_warps / num_stages are launch options, not kernel arguments.
        new_args = tuple(map(convert_arg, zip(self.func.__code__.co_varnames, args)))
        new_kwargs = {k: convert_arg((k, v)) for (k, v) in kwargs.items() if k not in ["num_warps", "num_stages"]}
        grid = self._get_grid(**kwargs)
        for program_id in program_ids_from_grid(grid):
            # Fresh proxy per program id so program_id/program_size are correct.
            proxy = TritonLangProxy(memory, ExecutionContext(program_id, grid))
            attach_triton(tl, proxy)
            self.func(*new_args, **new_kwargs)
        # NOTE(review): source indentation was lost in this capture; detach is
        # placed after the loop (attach overwrites the same backup each
        # iteration, so a single detach restores the originals) — confirm
        # against the upstream file.
        detach_triton(tl)
class GridSelector:
    """Entry point of the debugger.

    Wraps a kernel function so that both ``kernel[grid](...)`` and a plain
    ``kernel(...)`` call execute through DebuggerFunction.
    """

    def __init__(self, func):
        torch_version = torch.__version__
        assert torch_version[0] == "2", f"Triton Debugger only supports torch >= 2.0, using {torch_version}"
        self.func = func

    def __getitem__(self, grid):
        # kernel[grid] -> runnable bound to that grid.
        return DebuggerFunction(self.func, grid)

    def __call__(self, *args, **kwargs):
        # Direct call: default (1,) grid.
        return DebuggerFunction(self.func)(*args, **kwargs)
class AutotuneGridSelector:
    """Debugger counterpart of `triton.autotune`: defers launches to `AutotuneRunner`."""

    def __init__(self, func, autotune_params):
        self.func = func
        self.autotune_params = autotune_params

    def __getitem__(self, grid):
        return AutotuneRunner(self.func, self.autotune_params, grid)

    def __call__(self, *args, **kwargs):
        runner = AutotuneRunner(self.func, self.autotune_params)
        return runner(*args, **kwargs)
class AutotuneRunner:
    """Runs a kernel once per non-primary autotune config, then once with the primary config."""

    def __init__(self, func, autotune_params, grid=None):
        self.func = func
        self.autotune_params = autotune_params
        self.grid = grid

    def _launch(self, config, args, kwargs):
        # Launch through the grid selector when a grid was bound, else call directly.
        if self.grid:
            self.func[self.grid](*args, **kwargs, **config.kwargs)
        else:
            self.func(*args, **kwargs, **config.kwargs)

    def __call__(self, *args, **kwargs):
        assert len(self.autotune_params["configs"]) >= 1

        def clone_if_tensor(value):
            return torch.clone(value) if torch.is_tensor(value) else value

        # Warm-up passes: every non-primary config runs on cloned inputs so that
        # in-place writes do not pollute the final run.
        for config in self.autotune_params["configs"][1:]:
            cloned_args = tuple(clone_if_tensor(a) for a in args)
            cloned_kwargs = {k: clone_if_tensor(v) for k, v in kwargs.items()}
            self._launch(config, cloned_args, cloned_kwargs)

        # Final run: primary (first) config on the original, unclonned inputs.
        self._launch(self.autotune_params["configs"][0], args, kwargs)
def triton_debug_autotune(**kwargs):
    """
    Drop-in replacement for `triton.autotune` when running under the debugger.

    Keyword arguments are forwarded verbatim to `AutotuneGridSelector` /
    `AutotuneRunner`; they must contain at least a "configs" list, mirroring
    `triton.autotune`.

    :return: a decorator wrapping the kernel in an `AutotuneGridSelector`.
    """
    # Fixed typo in the parameter name (`kwars` -> `kwargs`); the name is
    # purely internal, so callers are unaffected.

    def wrapper(func):
        return AutotuneGridSelector(func, kwargs)

    return wrapper

View File

@@ -1,102 +0,0 @@
from __future__ import annotations
import dataclasses
from . import torch_wrapper
torch = torch_wrapper.torch
@dataclasses.dataclass
class RegisteredStorage:
    """Bookkeeping record for one torch storage tracked by the debugger's memory map."""

    storage: torch.Storage
    dtype: torch.dtype
    size: int
    ptr: int

    @property
    def end_ptr(self) -> int:
        """One-past-the-end address of the storage."""
        return self.size + self.ptr

    @property
    def access_tensor(self) -> torch.Tensor:
        """A tensor over the raw storage, typed with the registered dtype, for reads/writes."""
        return torch.tensor(self.storage, dtype=self.dtype, device=self.storage.device)

    def ensure_immutable(self):
        # Guard against the underlying tensor having been resized or reallocated
        # since registration.
        assert self.storage.data_ptr() == self.ptr and self.storage.size() == self.size
class MemoryMap:
    """
    Registry of every tensor handed to a debugged kernel, addressed by raw
    device pointer.  `tl.load`/`tl.store` are emulated by mapping pointer
    tensors back onto indices into the registered storages.
    """

    storages: "list[RegisteredStorage]"

    def __init__(self):
        self.storages = []

    def _get_registered_storage(self, pointer: torch.Tensor):
        # Every pointer in the batch must fall inside a single storage's range.
        lo = torch.min(pointer).item()
        hi = torch.max(pointer).item()
        match = None
        for registered in self.storages:
            if lo >= registered.ptr and hi < registered.end_ptr:
                match = registered
                break
        if match is None:
            raise Exception("Storage not found or pointers spanning multiple tensors")
        match.ensure_immutable()
        return match

    def add_tensor(self, t: torch.Tensor):
        # Register the backing storage and return the tensor's base address.
        storage = t.untyped_storage()
        self.storages.append(RegisteredStorage(storage, t.dtype, storage.size(), storage.data_ptr()))
        return t.data_ptr()

    def load(
        self,
        pointer: torch.Tensor,
        mask: torch.Tensor = None,
        other=0.0,
    ):
        assert pointer.is_cuda
        assert 0 < pointer.dim() < 3
        assert pointer.dtype == torch.int64

        if mask is None:
            mask = torch.ones_like(pointer).bool()
        assert mask.is_cuda
        assert 0 < mask.dim() < 3
        assert mask.dtype == torch.bool
        mask = mask.expand(pointer.size())

        if torch.all(~mask):
            # Todo: The type is wrong here, we can't determine the correct type
            return torch.full_like(pointer, fill_value=other, dtype=torch.float16, device="cuda")

        registered_storage = self._get_registered_storage(pointer[mask])
        access_tensor = registered_storage.access_tensor
        index_tensor = pointer - registered_storage.ptr

        # Masked-off lanes keep the `other` fill value.
        block = torch.full_like(pointer, fill_value=other, dtype=access_tensor.dtype, device="cuda")
        block[mask] = access_tensor[index_tensor[mask]]
        return block

    def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None):
        assert 0 < pointer.dim() < 3
        assert pointer.dtype == torch.int64

        if mask is None:
            mask = torch.ones_like(pointer).bool()
        assert 0 < mask.dim() < 3
        assert mask.dtype == torch.bool
        mask = mask.expand(pointer.size())

        if torch.all(~mask):
            # Nothing selected: the store is a no-op.
            return

        registered_storage = self._get_registered_storage(pointer[mask])
        access_tensor = registered_storage.access_tensor
        index_tensor = pointer - registered_storage.ptr
        access_tensor[index_tensor[mask]] = value[mask].to(access_tensor.dtype)

View File

@@ -1,641 +0,0 @@
from __future__ import annotations
from ..language import core as lcore
from . import torch_wrapper
from .core import ExecutionContext
from .memory_map import MemoryMap
torch = torch_wrapper.torch
def _primitive_to_tensor(x):
"""
Converts various Python primitive data types to PyTorch tensor.
"""
tensor_args = {"device": "cuda"}
if isinstance(x, bool):
return torch.tensor([x], dtype=torch.bool, **tensor_args)
elif isinstance(x, int):
if -(2**31) <= x < 2**31:
return torch.tensor([x], dtype=torch.int32, **tensor_args)
elif -(2**63) <= x < 2**63:
return torch.tensor([x], dtype=torch.int64, **tensor_args)
else:
raise RuntimeError(f"Nonrepresentable integer {x}.")
elif isinstance(x, float):
return torch.tensor([x], dtype=torch.float32, **tensor_args)
elif torch.is_tensor(x):
return x
elif isinstance(x, WrappedTensor):
return x
elif isinstance(x, debugger_constexpr):
if x.value is None:
return None
return _primitive_to_tensor(x.value)
elif x is None:
return None
assert False, f"cannot convert {x} of type {type(x)} to tensor"
def _infer_tensor(func):
"""
A decorator function to harmonize function args:
- converts primitives to PyTorch tensors
- wraps PyTorch tensors with WrappedTensors
"""
def wrapper(*args):
new_args = tuple(map(lambda v: _primitive_to_tensor(v), args))
new_args = tuple(map(lambda v: WrappedTensor(v) if torch.is_tensor(v) else v, new_args))
return func(*new_args)
return wrapper
def _tensor_operation(func):
"""
A decorator function to unwrap WrappedTensors and debugger_constexpr before calling the function.
Can be combined with _infer_tensor decorator to harmonize args (everything to torch tensor).
"""
def wrapper(*args, **kwargs):
for arg in args:
assert not torch.is_tensor(arg), "unexpected tensor argument"
def unwrap_tensor(v):
if isinstance(v, WrappedTensor):
return v.tensor
if isinstance(v, debugger_constexpr):
return v.value
return v
new_args = tuple(map(unwrap_tensor, args))
new_kwargs = {k: unwrap_tensor(v) for k, v in kwargs.items()}
result = func(args[0], *new_args[1:], **new_kwargs)
return WrappedTensor(result) if torch.is_tensor(result) else result
return wrapper
class debugger_constexpr:
    """Debugger-side stand-in for `triton.language.constexpr`: a thin box around a plain value."""

    def __init__(self, value):
        # Unbox nested constexprs so `.value` is always a raw Python value.
        self.value = value.value if isinstance(value, debugger_constexpr) else value

    @staticmethod
    def _unwrap(other):
        # Comparisons/bitwise ops accept either raw values or other constexprs.
        return other.value if isinstance(other, debugger_constexpr) else other

    def __str__(self) -> str:
        return "debugger_constexpr(" + str(self.value) + ")"

    def __index__(self) -> int:
        return self.value

    def __bool__(self):
        return bool(self.value)

    def __ge__(self, other):
        return self.value >= debugger_constexpr._unwrap(other)

    def __gt__(self, other):
        return self.value > debugger_constexpr._unwrap(other)

    def __le__(self, other):
        return self.value <= debugger_constexpr._unwrap(other)

    def __lt__(self, other):
        return self.value < debugger_constexpr._unwrap(other)

    def __eq__(self, other):
        return self.value == debugger_constexpr._unwrap(other)

    def __or__(self, other):
        return self.value | debugger_constexpr._unwrap(other)

    def __ror__(self, other):
        return self.value | debugger_constexpr._unwrap(other)

    def __and__(self, other):
        return self.value & debugger_constexpr._unwrap(other)

    def __rand__(self, other):
        return self.value & debugger_constexpr._unwrap(other)

    def to(self, dtype, bitcast=False, _builder=None):
        # Only the handful of casts the debugger needs are supported.
        if dtype in [torch.int64]:
            ret_ty = int
        elif dtype == torch.bool:
            ret_ty = bool
        elif dtype in [torch.float64]:
            ret_ty = float
        else:
            raise ValueError("dtype not supported in debugger")
        return debugger_constexpr(ret_ty(self.value))
class WrappedTensor:
    """
    Debugger-side stand-in for a Triton tensor value.

    Wraps a plain ``torch.Tensor`` and overloads arithmetic, bitwise and
    comparison operators.  The ``_infer_tensor`` / ``_tensor_operation``
    decorators handle operand conversion: primitives become tensors, wrapped
    operands are unwrapped before the torch call, and tensor results are
    re-wrapped on the way out.
    """

    def __init__(self, tensor):
        # The underlying torch tensor holding the actual data.
        self.tensor = tensor

    def __index__(self) -> int:
        # Only meaningful for size-1 tensors (`.item()` raises otherwise).
        return self.tensor.item()

    def __str__(self) -> str:
        return "wrapped_" + str(self.tensor)

    def __bool__(self) -> bool:
        # Truthy only when *every* element is True (an all() reduction).
        return torch.all(self.tensor == True).item()  # noqa: E712

    @property
    def dtype(self):
        return self.tensor.dtype

    @_infer_tensor
    @_tensor_operation
    def __add__(self, other):
        return torch.add(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __radd__(self, other):
        return self.__add__(other)

    @_infer_tensor
    @_tensor_operation
    def __sub__(self, other):
        return torch.sub(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rsub__(self, other):
        return torch.sub(other, self.tensor)

    @_infer_tensor
    @_tensor_operation
    def __mul__(self, other):
        return torch.mul(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rmul__(self, other):
        return self.__mul__(other)

    @_infer_tensor
    @_tensor_operation
    def __truediv__(self, other):
        return torch.div(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rtruediv__(self, other):
        return torch.div(other, self.tensor)

    @_infer_tensor
    @_tensor_operation
    def __floordiv__(self, other):
        return torch.floor_divide(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rfloordiv__(self, other):
        return torch.floor_divide(other, self.tensor)

    @_infer_tensor
    @_tensor_operation
    def __mod__(self, other):
        return torch.remainder(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rmod__(self, other):
        return torch.remainder(other, self.tensor)

    @_infer_tensor
    @_tensor_operation
    def __neg__(self):
        return -self.tensor

    @_infer_tensor
    @_tensor_operation
    def __invert__(self):
        return ~self.tensor

    @_infer_tensor
    @_tensor_operation
    def __and__(self, other):
        return torch.bitwise_and(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __or__(self, other):
        return torch.bitwise_or(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __xor__(self, other):
        return torch.bitwise_xor(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __lshift__(self, other):
        return torch.bitwise_left_shift(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rshift__(self, other):
        return torch.bitwise_right_shift(self.tensor, other)

    # Comparison operators below return elementwise boolean tensors, mirroring
    # Triton semantics (note the asymmetry with __eq__/__ne__ further down).

    @_infer_tensor
    @_tensor_operation
    def __gt__(self, other):
        return self.tensor > other

    @_infer_tensor
    @_tensor_operation
    def __rgt__(self, other):
        return other > self.tensor

    @_infer_tensor
    @_tensor_operation
    def __ge__(self, other):
        return self.tensor >= other

    @_infer_tensor
    @_tensor_operation
    def __rge__(self, other):
        return other >= self.tensor

    @_infer_tensor
    @_tensor_operation
    def __lt__(self, other):
        return self.tensor < other

    @_infer_tensor
    @_tensor_operation
    def __rlt__(self, other):
        return other < self.tensor

    @_infer_tensor
    @_tensor_operation
    def __le__(self, other):
        return self.tensor <= other

    @_infer_tensor
    @_tensor_operation
    def __rle__(self, other):
        return other <= self.tensor

    @_infer_tensor
    @_tensor_operation
    def __eq__(self, other):
        # NOTE(review): unlike the ordering operators above, this is a
        # whole-tensor comparison yielding a single bool, not an elementwise
        # tensor — confirm this asymmetry is intended.
        return torch.equal(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __ne__(self, other):
        return not torch.equal(self.tensor, other)

    @_tensor_operation
    def __getitem__(self, slices):
        # Plain torch indexing; the commented-out code below sketches Triton's
        # expand-dims slicing semantics, which are not emulated yet.
        return self.tensor.__getitem__(slices)

    # if isinstance(slices, slice):
    #     slices = [slices]
    # src_shape = self.shape
    # dst_shape = []
    # curr = 0
    # for sl in slices:
    #     if isinstance(sl, constexpr) and sl.value is None:
    #         dst_shape.append(1)
    #     elif sl == slice(None, None, None):
    #         dst_shape.append(src_shape[curr].value)
    #         curr += 1
    # ret = torch.reshape(self.tensor, dst_shape, )
    # return ret

    @_tensor_operation
    def to(self, dtype, bitcast=False):
        # NOTE(review): `bitcast` is accepted but ignored — a bitcast request
        # currently performs a value cast instead. Confirm before relying on it.
        return self.tensor.to(dtype)
        # if isinstance(bitcast, constexpr):
        #     bitcast = bitcast.value
        # if bitcast:
        #     return semantic.bitcast(self, dtype, )
        # return semantic.cast(self, dtype, )
def _constexpr_to_value(v):
    """Unbox a debugger_constexpr to its raw value; anything else passes through unchanged."""
    return v.value if isinstance(v, debugger_constexpr) else v
class TritonLangProxy:
    """
    Eager, torch-backed implementation of the `triton.language` API surface.

    An instance is attached in place of `tl` for the duration of one simulated
    kernel launch (see `DebuggerFunction.__call__`): memory accesses go through
    the launch's `MemoryMap`, and program ids come from the `ExecutionContext`.
    Methods that Triton supports but the debugger does not yet emulate raise
    `NotImplementedError`.
    """

    # Pointer-to-storage translation for the current launch.
    _memory_map: MemoryMap

    # Program id / grid size of the currently simulated program instance.
    _context: ExecutionContext

    def __init__(self, memory_map: MemoryMap, context: ExecutionContext):
        self._memory_map = memory_map
        self._context = context

    # Types
    # Removed void, int1, float8, uint16, uint32, uint64, pi32_t

    # constexpr = debugger_constexpr

    # Program functions

    @_tensor_operation
    def load(
        self,
        pointer: torch.Tensor,
        mask: torch.Tensor = None,
        other=0.0,
        cache_modifier="",
        eviction_policy="",
        volatile=False,
    ):
        # cache_modifier / eviction_policy / volatile are accepted for API
        # compatibility but have no effect in the eager emulation.
        return self._memory_map.load(pointer, mask, other)

    @_tensor_operation
    def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None):
        return self._memory_map.store(pointer, value, mask)

    @_tensor_operation
    def program_id(self, axis):
        # Returned as a size-1 int32 tensor to mimic a Triton scalar.
        assert axis < len(self._context.program_id)
        return torch.tensor([self._context.program_id[axis]], dtype=torch.int32, device="cuda")

    @_tensor_operation
    def num_programs(self, axis):
        assert axis < len(self._context.program_size)
        return torch.tensor([self._context.program_size[axis]], dtype=torch.int32, device="cuda")

    @_tensor_operation
    def arange(self, start, end):
        return torch.arange(start=start, end=end, dtype=torch.int32, device="cuda")

    @_tensor_operation
    def zeros(self, shape, dtype):
        # Shape elements must be constexpr ints, mirroring tl.zeros' contract.
        for i, d in enumerate(shape):
            if not isinstance(d, debugger_constexpr):
                raise TypeError(f"Shape element {i} must have type `constexpr`")
            if not isinstance(d.value, int):
                raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
        shape = [x.value for x in shape]
        # Map triton dtypes onto their torch equivalents.
        if isinstance(dtype, lcore.dtype):
            if dtype.is_fp32():
                dtype = torch.float32
            elif dtype.is_fp16():
                dtype = torch.float16
            elif dtype.is_bf16():
                dtype = torch.bfloat16
            elif dtype.is_int32():
                dtype = torch.int32
            elif dtype.is_int16():
                dtype = torch.int16
            elif dtype.is_int8():
                dtype = torch.int8
            else:
                raise TypeError(f"Unsupported dtype {dtype}")
        return torch.zeros(size=shape, dtype=dtype, device="cuda")

    @_tensor_operation
    def dequantize(self, input, scale, shift, nbit, dst_ty=None):
        if dst_ty is None:
            dst_ty = torch.float16
        raise NotImplementedError()

    @_tensor_operation
    def broadcast(self, input, other):
        raise NotImplementedError()

    @_tensor_operation
    def broadcast_to(self, input, shape):
        raise NotImplementedError()

    @_tensor_operation
    def cat(self, input, shape):
        raise NotImplementedError()

    @_tensor_operation
    def reshape(self, input, shape):
        raise NotImplementedError()

    @_tensor_operation
    def dot(self, input, other, trans_a=False, trans_b=False, allow_tf32=True):
        # allow_tf32 is ignored: torch.matmul decides precision on its own here.
        assert input.dtype == other.dtype
        if trans_a:
            input = input.T
        if trans_b:
            other = other.T
        return torch.matmul(input=input, other=other)

    # Atomic emulation: each op is a plain load-modify-store.  This is safe
    # only because the debugger runs program instances sequentially.

    @_tensor_operation
    def atomic_cas(self, pointer, cmp, val):
        stored = self._memory_map.load(pointer, None, 0.0)
        if not isinstance(cmp, torch.Tensor):
            cmp = torch.tensor([cmp], dtype=stored.dtype, device="cuda")
        if not isinstance(val, torch.Tensor):
            val = torch.tensor([val], dtype=stored.dtype, device="cuda")
        # NOTE(review): tensor `==` here assumes a size-1 pointer — confirm.
        if stored == cmp:
            self._memory_map.store(pointer, val, None)
        return stored

    @_tensor_operation
    def atomic_xchg(self, pointer, val, mask=None):
        if isinstance(val, int):
            val = torch.tensor([val], dtype=torch.int32, device="cuda")
        stored = self._memory_map.load(pointer, mask, 0.0)
        self._memory_map.store(pointer, val, mask)
        return stored

    @_tensor_operation
    def atomic_add(self, pointer, val, mask=None):
        # arbitrary other value as it will masked during storing
        stored = self._memory_map.load(pointer, mask, 0.0)
        result = stored + val
        self._memory_map.store(pointer, result, mask)
        return stored

    @_tensor_operation
    def atomic_max(self, pointer, val, mask=None):
        stored = self._memory_map.load(pointer, mask, 0.0)
        result = torch.maximum(stored, val)
        self._memory_map.store(pointer, result, mask)
        return stored

    @_tensor_operation
    def atomic_min(self, pointer, val, mask=None):
        stored = self._memory_map.load(pointer, mask, 0.0)
        result = torch.minimum(stored, val)
        self._memory_map.store(pointer, result, mask)
        return stored

    @_tensor_operation
    def atomic_and(self, pointer, val, mask=None):
        stored = self._memory_map.load(pointer, mask, 0)
        result = torch.bitwise_and(stored, val)
        self._memory_map.store(pointer, result, mask)
        return stored

    @_tensor_operation
    def atomic_or(self, pointer, val, mask=None):
        stored = self._memory_map.load(pointer, mask, 0)
        result = torch.bitwise_or(stored, val)
        self._memory_map.store(pointer, result, mask)
        return stored

    @_tensor_operation
    def atomic_xor(self, pointer, val, mask=None):
        stored = self._memory_map.load(pointer, mask, 0)
        result = torch.bitwise_xor(stored, val)
        self._memory_map.store(pointer, result, mask)
        return stored

    @_tensor_operation
    def where(self, condition, x, y):
        condition = _primitive_to_tensor(condition)
        x = _primitive_to_tensor(x)
        y = _primitive_to_tensor(y)
        return torch.where(condition, x, y)

    @_tensor_operation
    def umulhi(self, x, y):
        raise NotImplementedError()

    @_tensor_operation
    def fdiv(self, x, y, ieee_rounding=False):
        raise NotImplementedError()

    @_tensor_operation
    def exp(self, x):
        return torch.exp(x)

    @_tensor_operation
    def log(self, x):
        return torch.log(x)

    @_tensor_operation
    def cos(self, x):
        return torch.cos(x)

    @_tensor_operation
    def sin(self, x):
        return torch.sin(x)

    @_tensor_operation
    def sqrt(self, x):
        return torch.sqrt(x)

    @_tensor_operation
    def globaltimer(self):
        raise NotImplementedError()

    @_tensor_operation
    def clock(self):
        raise NotImplementedError()

    @_tensor_operation
    def debug_barrier(self):
        raise NotImplementedError()

    # Compiler hints are identity functions in the eager emulation.

    @_tensor_operation
    def multiple_of(self, input, values):
        return input

    @_tensor_operation
    def max_contiguous(self, input, values):
        return input

    @_tensor_operation
    def max_constancy(self, input, values):
        return input

    @_tensor_operation
    def abs(self, x):
        return torch.abs(x)

    @_tensor_operation
    def cdiv(self, x, div):
        # Ceiling division for non-negative operands.
        return (x + div - 1) // div

    @_tensor_operation
    def minimum(self, x, y):
        if isinstance(x, int):
            x = torch.tensor(x, device="cuda")
        if isinstance(y, int):
            y = torch.tensor(y, device="cuda")
        return torch.minimum(x, y)

    @_tensor_operation
    def maximum(self, x, y):
        return torch.maximum(x, y)

    @_tensor_operation
    def sigmoid(self, x):
        raise NotImplementedError()

    @_tensor_operation
    def softmax(self, x, ieee_rounding=False):
        raise NotImplementedError()

    @_tensor_operation
    def ravel(self, x):
        raise NotImplementedError()

    @_tensor_operation
    def swizzle2d(self, i, j, size_i, size_j, size_g):
        raise NotImplementedError()

    @_tensor_operation
    def zeros_like(self, input):
        raise NotImplementedError()

    # Reductions: axis=None reduces over all elements, otherwise along `axis`.

    @_tensor_operation
    def max(self, input, axis=None):
        if axis is None:
            return torch.max(input)
        return torch.max(input, dim=axis).values

    @_tensor_operation
    def argmax(self, input, axis):
        raise NotImplementedError()

    @_tensor_operation
    def min(self, input, axis=None):
        if axis is None:
            return torch.min(input)
        return torch.min(input, dim=axis).values

    @_tensor_operation
    def argmin(self, input, axis):
        raise NotImplementedError()

    @_tensor_operation
    def sum(self, input, axis=None):
        if axis is None:
            return torch.sum(input)
        return torch.sum(input, dim=axis)

    @_tensor_operation
    def xor_sum(self, input, axis):
        raise NotImplementedError()

    @_tensor_operation
    def cumsum(self, input, axis=None):
        # NOTE(review): torch.cumsum requires a dim argument, so the axis=None
        # branch looks like it would raise — confirm whether it is reachable.
        if axis is None:
            return torch.cumsum(input)
        return torch.cumsum(input, dim=axis)

    @_tensor_operation
    def cumprod(self, input, axis=None):
        # NOTE(review): same dim-argument caveat as cumsum above.
        if axis is None:
            return torch.cumprod(input)
        return torch.cumprod(input, dim=axis)

View File

@@ -1,18 +0,0 @@
try:
    import torch as _real_torch
except ImportError:
    _real_torch = None


class TorchWrapper:
    """
    Lazy proxy that makes torch an optional dependency: attribute access is
    forwarded to the real torch module, or fails with a clear error when
    PyTorch is not installed.
    """

    def __getattr__(self, name):
        if _real_torch is None:
            raise ImportError("Triton requires PyTorch to be installed")
        return getattr(_real_torch, name)


# Module-level proxy imported as `torch` by the rest of the debugger.
torch = TorchWrapper()

View File

@@ -293,7 +293,7 @@ class dtype:
return self.name
def __repr__(self):
return f'triton.language.{self.name}'
return f'triton.language.{str(self)}'
class pointer_type(dtype):
@@ -551,9 +551,7 @@ class tensor:
# IR handle
self.handle = handle
# Block shape
self.shape = (1, )
if type.is_block():
self.shape = type.shape
self.shape = type.shape if type.is_block() else ()
self.numel = 1
for s in self.shape:
self.numel *= s
@@ -564,14 +562,15 @@ class tensor:
self.shape = [constexpr(s) for s in self.shape]
def __str__(self) -> str:
# ex. "float32[3,4]"
return str(self.dtype) + '[' + ','.join(str(s) for s in self.shape) + ']'
# ex. "float32[16, 32]"
return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']'
@builtin
def __add__(self, other, _builder=None):
other = _to_tensor(other, _builder)
return semantic.add(self, other, _builder)
@builtin
def __radd__(self, other, _builder=None):
return self.__add__(other, _builder=_builder)
@@ -580,6 +579,7 @@ class tensor:
other = _to_tensor(other, _builder)
return semantic.sub(self, other, _builder)
@builtin
def __rsub__(self, other, _builder=None):
other = _to_tensor(other, _builder)
return semantic.sub(other, self, _builder)
@@ -589,6 +589,7 @@ class tensor:
other = _to_tensor(other, _builder)
return semantic.mul(self, other, _builder)
@builtin
def __rmul__(self, other, _builder=None):
return self.__mul__(other, _builder=_builder)
@@ -597,6 +598,7 @@ class tensor:
other = _to_tensor(other, _builder)
return semantic.truediv(self, other, _builder)
@builtin
def __rtruediv__(self, other, _builder=None):
other = _to_tensor(other, _builder)
return semantic.truediv(other, self, _builder)
@@ -688,8 +690,6 @@ class tensor:
else:
return semantic.lshr(other, self, _builder)
# comparison operators
# >
@builtin
def __gt__(self, other, _builder=None):
@@ -763,11 +763,11 @@ class tensor:
@builtin
def __getitem__(self, slices, _builder=None):
if isinstance(slices, slice):
if isinstance(slices, (slice, constexpr)):
slices = [slices]
ret = self
for dim, sl in enumerate(slices):
if isinstance(sl, constexpr) and sl.value is None:
if sl is None or isinstance(sl, constexpr) and sl.value is None:
ret = semantic.expand_dims(ret, dim, _builder)
elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None:
pass
@@ -852,6 +852,8 @@ def arange(start, end, _builder=None):
def _shape_check_impl(shape):
shape = _constexpr_to_value(shape)
for i, d in enumerate(shape):
if isinstance(d, int):
d = constexpr(d)
if not isinstance(d, constexpr):
raise TypeError(f"Shape element {i} must have type `constexpr`")
if not isinstance(d.value, int):
@@ -930,6 +932,12 @@ def broadcast_to(input, shape, _builder=None):
@builtin
def trans(input, _builder=None):
"""
Returns a transposed tensor.
:param input: The input tensor.
:type input:
"""
return semantic.trans(input, _builder)
@@ -968,6 +976,15 @@ def view(input, shape, _builder=None):
@builtin
def reshape(input, shape, _builder=None):
"""
Returns a tensor with the same number of elements as input but with the
provided shape.
:param input: The input tensor.
:type input:
:param shape: The new shape.
:type shape: Tuple[int]
"""
shape = _shape_check_impl(shape)
return semantic.reshape(input, shape, _builder)
@@ -1012,7 +1029,7 @@ def expand_dims(input, axis, _builder=None):
@builtin
def dot(input, other, allow_tf32=True, out_dtype=float32, _builder=None):
def dot(input, other, acc=None, allow_tf32=True, max_num_imprecise_acc=None, out_dtype=float32, _builder=None):
"""
Returns the matrix product of two blocks.
@@ -1025,7 +1042,8 @@ def dot(input, other, allow_tf32=True, out_dtype=float32, _builder=None):
"""
allow_tf32 = _constexpr_to_value(allow_tf32)
out_dtype = _constexpr_to_value(out_dtype)
return semantic.dot(input, other, allow_tf32, out_dtype, _builder)
max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc)
return semantic.dot(input, other, acc, allow_tf32, max_num_imprecise_acc, out_dtype, _builder)
# -----------------------
@@ -1266,6 +1284,14 @@ def where(condition, x, y, _builder=None):
@builtin
def umulhi(x, y, _builder=None):
"""
Returns the most significant 32 bits of the product of x and y.
:param x: the input tensor
:type x: int32
:param y: the input tensor
:type y: int32
"""
x = _to_tensor(x, _builder)
y = _to_tensor(y, _builder)
return semantic.umulhi(x, y, _builder)
@@ -1273,6 +1299,15 @@ def umulhi(x, y, _builder=None):
@builtin
def fdiv(x, y, ieee_rounding=False, _builder=None):
"""
Returns a floating-point resultant tensor of dividing x by y.
:param x: the input numerator value.
:param y: the input denominator value.
:param ieee_rounding: To follow IEEE-754 floating point number
rounding mechanism
:type ieee_rounding: bool
"""
ieee_rounding = _constexpr_to_value(ieee_rounding)
return semantic.fdiv(x, y, ieee_rounding, _builder)

View File

@@ -13,3 +13,8 @@ def smid(_builder=None):
return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [],
dtype=core.int32, is_pure=True,
pack=1, _builder=_builder)
@core.builtin
def num_threads(_builder=None):
return core.constexpr(_builder.target.num_warps * 32)

View File

@@ -12,6 +12,13 @@ import re
T = TypeVar('T')
# TODO: redundant code -- remove after 3P backend refactor
def _is_cuda(target):
from ..compiler.compiler import CudaTargetDescriptor
return isinstance(target, CudaTargetDescriptor)
# Create custom exception that prints message "hello"
@@ -28,10 +35,14 @@ class IncompatibleTypeErrorImpl(Exception):
# ===----------------------------------------------------------------------===##
def program_id(axis: int, builder: ir.builder) -> tl.tensor:
if axis not in (0, 1, 2):
raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}")
return tl.tensor(builder.create_get_program_id(axis), tl.int32)
def num_programs(axis: int, builder: ir.builder) -> tl.tensor:
if axis not in (0, 1, 2):
raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}")
return tl.tensor(builder.create_get_num_programs(axis), tl.int32)
# ===----------------------------------------------------------------------===//
@@ -131,6 +142,8 @@ def add(input: tl.tensor,
input, other = binary_op_type_checking_impl(input, other, builder, True, True)
input_scalar_ty = input.type.scalar
other_scalar_ty = other.type.scalar
if input_scalar_ty.is_ptr() and other_scalar_ty.is_ptr():
raise ValueError("cannot add pointers together")
# offset + ptr
# ptr + offset
@@ -504,19 +517,18 @@ def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.te
if isinstance(value, tl.tensor):
assert value.numel.value == 1, "only accepts size-1 tensor"
value = cast(value, dtype, builder)
ret_ty = tl.block_type(value.dtype, shape)
return tl.tensor(builder.create_splat(value.handle, shape), ret_ty)
else:
# scalar
if dtype is None:
raise ValueError("dtype must be specified when value is not a tensor")
if value == 0:
value = builder.get_null_value(dtype.to_ir(builder))
else:
get_value_fn = getattr(builder, f"get_{dtype.name}")
value = get_value_fn(value)
if dtype is None:
raise ValueError("dtype must be specified when value is not a tensor")
ret_ty = tl.block_type(dtype, shape)
return tl.tensor(builder.create_splat(value, shape), ret_ty)
value = tl.tensor(value, dtype)
return splat(value, shape, builder)
@@ -529,6 +541,13 @@ def ones(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
# Shape Manipulation
# ===----------------------------------------------------------------------===//
def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor:
assert not value.type.is_block(), "Cannot splat a block tensor"
if len(shape) == 0:
return value
ret_ty = tl.block_type(value.dtype, shape)
return tl.tensor(builder.create_splat(value.handle, shape), ret_ty)
def view(input: tl.tensor,
dst_shape: List[int],
@@ -553,8 +572,12 @@ def reshape(input: tl.tensor,
def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
dst_shape = list(input.type.shape)
dst_shape = [tl._constexpr_to_value(x) for x in input.shape]
dst_shape.insert(axis, 1)
if not input.type.is_block():
return splat(input, shape=dst_shape, builder=builder)
ret_ty = tl.block_type(input.type.scalar, dst_shape)
return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty)
@@ -674,11 +697,6 @@ def bitcast(input: tl.tensor,
dst_ty)
# TODO: architecture descriptor class
def _is_cuda(arch):
return isinstance(arch, int)
def cast(input: tl.tensor,
dst_ty: tl.dtype,
builder: ir.builder) -> tl.tensor:
@@ -693,7 +711,7 @@ def cast(input: tl.tensor,
src_sca_ty = src_ty.scalar
dst_sca_ty = dst_ty.scalar
if _is_cuda(builder.arch) and builder.arch < 89 and \
if _is_cuda(builder.target) and builder.target.capability < 89 and \
(src_sca_ty.is_fp8e4nv() or dst_sca_ty.is_fp8e4nv()):
assert False, "fp8e4nv data type is not supported on CUDA arch < 89"
@@ -1139,13 +1157,20 @@ def atomic_max(ptr: tl.tensor,
# for float
# return atomic_smax(i_ptr, i_val) if val >= 0
# return atomic_umin(i_ptr, i_val) if val < 0
i_val = bitcast(val, tl.int32, builder)
i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
pos = greater_equal(val, tl.tensor(builder.get_fp32(0), sca_ty), builder)
neg = less_than(val, tl.tensor(builder.get_fp32(0), sca_ty), builder)
if sca_ty not in {tl.float32, tl.float64}:
raise TypeError(f"atomic_max not supported for dtype {sca_ty}")
itype = tl.int32 if sca_ty == tl.float32 else tl.float64
zero = full([], 0.0, sca_ty, builder)
i_val = bitcast(val, itype, builder)
i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder)
pos = greater_equal(val, zero, builder)
neg = less_than(val, zero, builder)
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle, sem), i_val.type)
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle, sem), i_val.type)
return where(pos, pos_ret, neg_ret, builder)
ret = where(pos, pos_ret, neg_ret, builder)
return bitcast(ret, sca_ty, builder)
def atomic_min(ptr: tl.tensor,
@@ -1175,10 +1200,16 @@ def atomic_min(ptr: tl.tensor,
# for float
# return atomic_smin(i_ptr, i_val) if val >= 0
# return atomic_umax(i_ptr, i_val) if val < 0
i_val = bitcast(val, tl.int32, builder)
i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
pos = greater_equal(val, tl.tensor(builder.get_fp32(0), sca_ty), builder)
neg = less_than(val, tl.tensor(builder.get_fp32(0), sca_ty), builder)
if sca_ty not in {tl.float32, tl.float64}:
raise TypeError(f"atomic_min not supported for dtype {sca_ty}")
itype = tl.int32 if sca_ty == tl.float32 else tl.float64
zero = full([], 0.0, sca_ty, builder)
i_val = bitcast(val, itype, builder)
i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder)
pos = greater_equal(val, zero, builder)
neg = less_than(val, zero, builder)
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
i_ptr.handle,
i_val.handle,
@@ -1191,7 +1222,8 @@ def atomic_min(ptr: tl.tensor,
and_(mask, neg, builder).handle,
sem),
i_val.type)
return where(pos, pos_ret, neg_ret, builder)
ret = where(pos, pos_ret, neg_ret, builder)
return bitcast(ret, sca_ty, builder)
def atomic_add(ptr: tl.tensor,
@@ -1302,11 +1334,27 @@ def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
return False
return True
def gpu_has_mfma() -> bool:
if not is_hip():
return False
return True # mfma supported in ['gfx908', 'gfx90a']
def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
if not gpu_has_mfma():
return False
# TODO: Add check for configurations and types.
return True
def dot(lhs: tl.tensor,
rhs: tl.tensor,
acc: tl.tensor,
allow_tf32: bool,
max_num_imprecise_acc: int,
out_dtype: tl.dtype,
builder: ir.builder) -> tl.tensor:
<<<<<<< HEAD
def assert_dtypes_valid(lhs_dtype, rhs_dtype, arch):
if is_hip():
assert lhs.dtype == rhs.dtype or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp16()) or \
@@ -1320,6 +1368,31 @@ def dot(lhs: tl.tensor,
# Checks for cuda arch
if arch < 90:
assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(), "Dot op does not support fp8e4nv on CUDA arch < 90"
=======
def assert_dtypes_valid(lhs_dtype, rhs_dtype, target):
# Checks for non-cuda archs
if not _is_cuda(target):
assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
return
# Checks for cuda arch
if target.capability < 90:
assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(), "Dot op does not support fp8e4nv on CUDA arch < 90"
if lhs_dtype.is_fp8() and rhs_dtype.is_fp8():
return
assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
else:
assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(), "Dot op does not support fp8e4b15 on CUDA arch >= 90"
assert not lhs_dtype.is_fp8e4b15x4() and not rhs_dtype.is_fp8e4b15x4(), "Dot op does not support fp8e4b15x4 on CUDA arch >= 90"
if lhs_dtype.is_int() or rhs_dtype.is_int():
assert lhs_dtype == rhs_dtype, f"Both operands must be same type. First operand ({lhs_dtype}) and second operand ({rhs_dtype})"
assert lhs_dtype.is_int8() or lhs_dtype.is_uint8(), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})"
elif lhs_dtype.is_fp8() or rhs_dtype.is_fp8():
assert lhs_dtype.is_fp8e4nv() or lhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. First operand ({lhs_dtype})"
assert rhs_dtype.is_fp8e4nv() or rhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. Second operand ({rhs_dtype})"
else:
assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1(), f"Unsupported dtype {lhs_dtype}"
assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1(), f"Unsupported dtype {rhs_dtype}"
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
else:
assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(), "Dot op does not support fp8e4b15 on CUDA arch >= 90"
@@ -1339,8 +1412,12 @@ def dot(lhs: tl.tensor,
assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
return
<<<<<<< HEAD
assert lhs.type.is_block() and rhs.type.is_block()
assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.arch)
=======
assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.target)
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!"
assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!"
@@ -1367,6 +1444,8 @@ def dot(lhs: tl.tensor,
assert is_hip() or lhs.shape[1].value >= 32, "small blocks not supported!"
_0 = builder.get_int32(0)
ret_scalar_ty = tl.int32
elif out_dtype.is_bf16():
raise ValueError("out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`")
elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16():
_0 = builder.get_fp32(0)
ret_scalar_ty = tl.float32
@@ -1401,10 +1480,25 @@ def dot(lhs: tl.tensor,
ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
ret_ty)
return cast(ret, ret_scalar_ty, builder)
<<<<<<< HEAD
_0 = builder.create_splat(_0, [M, N])
=======
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
ret_ty = tl.block_type(ret_scalar_ty, [M, N])
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
if acc is None:
acc_handle = builder.create_splat(_0, [M, N])
else:
acc_handle = acc.handle
assert acc.type == ret_ty
# max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
if not (_is_cuda(builder.target) and builder.target.capability == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and ret_scalar_ty.is_fp32()):
max_num_imprecise_acc = 0
if max_num_imprecise_acc is None:
max_num_imprecise_acc = 2**30
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc),
ret_ty)
@@ -1574,7 +1668,7 @@ def abs(x: tl.tensor, builder: ir.builder) -> tl.tensor:
def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
if len(x.shape) != len(values):
if max(1, len(x.shape)) != len(values):
raise ValueError("Shape of input to multiple_of does not match the length of values")
x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context()))
return x
@@ -1614,6 +1708,8 @@ def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno:
def _convert_elem_to_ir_value(builder, elem, require_i64):
if isinstance(elem, int):
elem = tl.constexpr(elem)
if isinstance(elem, tl.constexpr):
return builder.get_int64(elem.value) if require_i64 else builder.get_int32(elem.value)
elif isinstance(elem, tl.tensor):

View File

@@ -160,7 +160,7 @@ def max(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
else:
return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast)
else:
if core.constexpr(input.dtype.primitive_bitwidth) < 32:
if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32):
if core.constexpr(input.dtype.is_floating()):
input = input.to(core.float32)
else:

View File

@@ -21,9 +21,10 @@ def _fwd_kernel(
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_vz, stride_vh, stride_vn, stride_vk,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX,
Z_H_N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
IS_CAUSAL: tl.constexpr,
@@ -31,27 +32,21 @@ def _fwd_kernel(
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
qvk_offset = off_hz * stride_qh
Q_block_ptr = tl.make_block_ptr(
base=Q + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
vk_offset = qvk_offset // stride_qm
K_block_ptr = tl.make_block_ptr(
base=K + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
base=K,
shape=(BLOCK_DMODEL, Z_H_N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
offsets=(0, vk_offset),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
base=V,
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_vn, stride_vk),
offsets=(vk_offset, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
)
@@ -68,7 +63,11 @@ def _fwd_kernel(
# don't work as expected with `exp` in the loop
qk_scale = sm_scale * 1.44269504
# load q: it will stay in SRAM throughout
q = tl.load(Q_block_ptr)
offs_k = tl.arange(0, BLOCK_DMODEL)
Q_ptrs = Q + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
q = tl.load(Q_ptrs)
q = (q * qk_scale).to(K.dtype.element_ty)
lo = 0
hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
@@ -86,8 +85,7 @@ def _fwd_kernel(
alpha = tl.math.exp2(m_i - m_i_new)
p = tl.math.exp2(qk - m_i_new[:, None])
# -- scale and update acc --
acc_scale = l_i * 0 + alpha # workaround some compiler bug
acc *= acc_scale[:, None]
acc *= alpha[:, None]
acc += tl.dot(p.to(V.dtype.element_ty), v, allow_tf32=True)
# -- update m_i and l_i --
l_i = l_i * alpha + tl.sum(p, 1)
@@ -101,13 +99,14 @@ def _fwd_kernel(
tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
O_block_ptr = tl.make_block_ptr(
base=Out + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
base=Out,
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
offsets=(vk_offset + start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
# O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
tl.store(O_block_ptr, acc.to(K.dtype.element_ty))
@@ -137,13 +136,14 @@ def _bwd_kernel_one_col_block(
D,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_vz, stride_vh, stride_vn, stride_vk,
Z, H, N_CTX,
off_hz, start_n, num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
CAUSAL: tl.constexpr,
MMA_V3: tl.constexpr
):
if SEQUENCE_PARALLEL:
DQ += stride_dqa.to(tl.int64) * start_n
@@ -159,7 +159,7 @@ def _bwd_kernel_one_col_block(
# initialize pointers to value-like data
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
v_ptrs = V + (offs_n[:, None] * stride_vk + offs_k[None, :] * stride_vn)
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk)
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
# pointer to row-wise quantities in value-like data
@@ -203,8 +203,11 @@ def _bwd_kernel_one_col_block(
dq += tl.dot(ds, k, allow_tf32=True)
tl.store(dq_ptrs, dq)
elif SEQUENCE_PARALLEL:
# dq = tl.dot(ds, k, allow_tf32=True)
dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds), allow_tf32=True))
if MMA_V3:
dq = tl.dot(ds, k, allow_tf32=True)
else:
# not work with mma v3, becuase M % 64 != 0
dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds), allow_tf32=True))
tl.store(dq_ptrs, dq)
# increment pointers
@@ -212,7 +215,7 @@ def _bwd_kernel_one_col_block(
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm
# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_vk + offs_k[None, :] * stride_vn)
dv_ptrs = DV + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk)
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
@@ -228,12 +231,13 @@ def _bwd_kernel(
D,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_vz, stride_vh, stride_vn, stride_vk,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
CAUSAL: tl.constexpr,
MMA_V3: tl.constexpr
# fmt: on
):
qk_scale = sm_scale * 1.44269504
@@ -259,13 +263,14 @@ def _bwd_kernel(
D,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_vz, stride_vh, stride_vn, stride_vk,
Z, H, N_CTX,
off_hz, start_n, num_block_n,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
MMA_V3=MMA_V3
)
else:
start_n = tl.program_id(1)
@@ -276,13 +281,14 @@ def _bwd_kernel(
D,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_vz, stride_vh, stride_vn, stride_vk,
Z, H, N_CTX,
off_hz, start_n, num_block_n,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
MMA_V3=MMA_V3
)
@@ -317,6 +323,7 @@ class _attention(torch.autograd.Function):
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
q.shape[0] * q.shape[1] * q.shape[2],
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
IS_CAUSAL=causal,
num_warps=num_warps,
@@ -332,6 +339,8 @@ class _attention(torch.autograd.Function):
@staticmethod
def backward(ctx, do):
capability = torch.cuda.get_device_capability()
MMA_V3 = capability[0] >= 9
BLOCK = 128
q, k, v, o, L = ctx.saved_tensors
sequence_parallel = ctx.sequence_parallel
@@ -365,6 +374,7 @@ class _attention(torch.autograd.Function):
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
SEQUENCE_PARALLEL=sequence_parallel,
CAUSAL=ctx.causal,
MMA_V3=MMA_V3,
num_warps=8,
num_stages=1,
)

View File

@@ -82,6 +82,7 @@ def _kernel(A, B, C, M, N, K,
stride_cm, stride_cn,
dot_out_dtype: tl.constexpr,
allow_tf32: tl.constexpr,
fp8_fast_accum: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr
):
@@ -118,7 +119,10 @@ def _kernel(A, B, C, M, N, K,
if AB_DTYPE:
a = a.to(C.dtype.element_ty)
b = b.to(C.dtype.element_ty)
acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
if fp8_fast_accum:
acc = tl.dot(a, b, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
else:
acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk
acc = acc.to(C.dtype.element_ty)
@@ -140,7 +144,7 @@ class _matmul(torch.autograd.Function):
_locks = {}
@staticmethod
def _call(a, b, dot_out_dtype, allow_tf32):
def _call(a, b, dot_out_dtype, allow_tf32, fp8_fast_accum):
device = a.device
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
@@ -155,6 +159,8 @@ class _matmul(torch.autograd.Function):
if a.dtype in [tl.float8e4nv, tl.float8e4b15, tl.float8e5] or\
b.dtype in [tl.float8e4nv, tl.float8e4b15, tl.float8e5]:
c_dtype = torch.float16
elif a.dtype in [torch.int8] or b.dtype in [torch.int8]:
c_dtype = torch.int32
else:
c_dtype = get_higher_dtype(a.dtype, b.dtype)
c = torch.empty((M, N), device=device, dtype=c_dtype)
@@ -174,6 +180,8 @@ class _matmul(torch.autograd.Function):
ab_dtype = True
if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [tl.float8e4nv, tl.float8e5]:
ab_dtype = False
if a.dtype in [torch.int8] and b.dtype in [torch.int8]:
ab_dtype = False
# launch kernel
grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, M, N, K,
@@ -182,12 +190,13 @@ class _matmul(torch.autograd.Function):
c.stride(0), c.stride(1),
dot_out_dtype=dot_out_dtype,
allow_tf32=allow_tf32,
fp8_fast_accum=fp8_fast_accum,
GROUP_M=8, AB_DTYPE=ab_dtype)
return c
@staticmethod
def forward(ctx, a, b, dot_out_dtype=None, allow_tf32=True):
return _matmul._call(a, b, dot_out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
def forward(ctx, a, b, dot_out_dtype=None, allow_tf32=True, fp8_fast_accum=True):
return _matmul._call(a, b, dot_out_dtype=dot_out_dtype, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum)
matmul = _matmul.apply

View File

@@ -5,14 +5,16 @@ import torch
from .. import cdiv
from .._C.libtriton.triton import runtime
from ..runtime import driver
from ..testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops
from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops,
nvsmi)
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4)
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device)
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, cur_sm_clock, backend, device)
return tflops
@@ -20,7 +22,8 @@ def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4)
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device)
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, cur_sm_clock, backend, device)
return tflops

View File

@@ -132,9 +132,11 @@ class Autotuner(KernelInterface):
full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs}
if config.pre_hook is not None:
config.pre_hook(full_nargs)
return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization, **kwargs, **config.kwargs)
ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization, **kwargs, **config.kwargs)
self.nargs = None
return ret
def prune_configs(self, kwargs):
pruned_configs = self.configs

View File

@@ -11,7 +11,10 @@ static inline void gpuAssert(CUresult code, const char *file, int line) {
char err[1024] = {0};
strcat(err, prefix);
strcat(err, str);
PyGILState_STATE gil_state;
gil_state = PyGILState_Ensure();
PyErr_SetString(PyExc_RuntimeError, err);
PyGILState_Release(gil_state);
}
}
@@ -327,7 +330,7 @@ static PyObject *memFree(PyObject *self, PyObject *args) {
// Helper function to convert a Python list to a cuuint64_t array
static cuuint64_t *list_to_cuuint64_array(PyObject *listObj) {
Py_ssize_t len = PyList_Size(listObj);
cuuint64_t *array = malloc(len * sizeof(cuuint64_t));
cuuint64_t *array = (cuuint64_t *)malloc(len * sizeof(cuuint64_t));
for (Py_ssize_t i = 0; i < len; i++) {
PyObject *item = PyList_GetItem(listObj, i);
array[i] = (cuuint64_t)PyLong_AsUnsignedLongLong(item);
@@ -338,7 +341,7 @@ static cuuint64_t *list_to_cuuint64_array(PyObject *listObj) {
// Helper function to convert a Python list to a cuuint32_t array
static cuuint32_t *list_to_cuuint32_array(PyObject *listObj) {
Py_ssize_t len = PyList_Size(listObj);
cuuint32_t *array = malloc(len * sizeof(cuuint32_t));
cuuint32_t *array = (cuuint32_t *)malloc(len * sizeof(cuuint32_t));
for (Py_ssize_t i = 0; i < len; i++) {
PyObject *item = PyList_GetItem(listObj, i);
array[i] = (cuuint32_t)PyLong_AsUnsignedLong(item);

View File

@@ -13,7 +13,10 @@ static inline void gpuAssert(hipError_t code, const char *file, int line) {
const char *str = hipGetErrorString(code);
char err[1024] = {0};
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str);
PyGILState_STATE gil_state;
gil_state = PyGILState_Ensure();
PyErr_SetString(PyExc_RuntimeError, err);
PyGILState_Release(gil_state);
}
}
}

View File

@@ -10,6 +10,14 @@ def default_cache_dir():
return os.path.join(Path.home(), ".triton", "cache")
def default_override_dir():
return os.path.join(Path.home(), ".triton", "override")
def default_dump_dir():
return os.path.join(Path.home(), ".triton", "dump")
class CacheManager(ABC):
def __init__(self, key):
pass
@@ -36,17 +44,26 @@ class CacheManager(ABC):
class FileCacheManager(CacheManager):
def __init__(self, key):
def __init__(self, key, override=False, dump=False):
self.key = key
self.lock_path = None
# create cache directory if it doesn't exist
self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir()
if self.cache_dir:
if (dump):
self.cache_dir = default_dump_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
elif (override):
self.cache_dir = default_override_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
else:
raise RuntimeError("Could not create or locate cache dir")
# create cache directory if it doesn't exist
self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir()
if self.cache_dir:
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
else:
raise RuntimeError("Could not create or locate cache dir")
def _make_path(self, filename) -> str:
return os.path.join(self.cache_dir, filename)
@@ -131,3 +148,11 @@ def get_cache_manager(key) -> CacheManager:
__cache_cls_nme = user_cache_manager
return __cache_cls(key)
def get_override_manager(key) -> CacheManager:
return __cache_cls(key, override=True)
def get_dump_manager(key) -> CacheManager:
return __cache_cls(key, dump=True)

View File

@@ -0,0 +1,527 @@
import inspect
import numpy as np
import triton
import triton.language as tl
from .._C.libtriton.triton import interpreter as _interpreter
# TODO: duplicate
def str_to_ty(name):
language = tl
if name[0] == "*":
ty = str_to_ty(name[1:])
return language.pointer_type(ty)
tys = {
"fp8e4nv": language.float8e4nv,
"fp8e5": language.float8e5,
"fp8e4b15": language.float8e4b15,
"fp8e4b15x4": language.float8e4b15x4,
"fp16": language.float16,
"bf16": language.bfloat16,
"fp32": language.float32,
"fp64": language.float64,
"i1": language.int1,
"i8": language.int8,
"i16": language.int16,
"i32": language.int32,
"i64": language.int64,
"u8": language.uint8,
"u16": language.uint16,
"u32": language.uint32,
"u64": language.uint64,
"B": language.int1,
}
return tys[name]
class TensorHandle:
def __init__(self, data, dtype):
self.data = data
self.dtype = dtype
def __bool__(self):
return bool(self.data.all())
class BlockPointerHandle:
def __init__(self, base, shape, strides, offsets, tensor_shape, order):
self.base = base
self.shape = shape
self.strides = strides
self.offsets = offsets
self.tensor_shape = tensor_shape
self.order = order
def materialize_pointers(self, boundary_check):
dtype_tt = self.base.dtype.element_ty
n_bytes = dtype_tt.primitive_bitwidth // 8
tensor_shape = self.tensor_shape
ptrs = np.broadcast_to(self.base.data, self.tensor_shape)
masks = np.ones(self.tensor_shape, dtype=bool)
for dim in range(len(tensor_shape)):
bcast_dims = [1] * len(tensor_shape)
bcast_dims[dim] = tensor_shape[dim]
off = (self.offsets[dim].data + np.arange(tensor_shape[dim])).reshape(bcast_dims)
ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64)
if dim in boundary_check:
masks = np.logical_and(masks, off < self.shape[dim].data)
ptrs = TensorHandle(ptrs, self.base.dtype)
return ptrs, masks
def wrap_ret(compute_ret_ty):
def wrapper(fn):
def wrapped(*args, **kwargs):
ret = fn(*args, **kwargs)
return TensorHandle(ret.data, compute_ret_ty(*args, **kwargs))
return wrapped
return wrapper
class Builder:
def __init__(self) -> None:
self.arch = None
# pass
def set_grid_idx(self, x, y, z):
assert x < self.grid_dim[0]
assert y < self.grid_dim[1]
assert z < self.grid_dim[2]
self.grid_idx = (x, y, z)
def set_grid_dim(self, nx, ny, nz):
self.grid_dim = (nx, ny, nz)
def np_dtype(self, tt_dtype):
if isinstance(tt_dtype, tl.pointer_type):
return np.dtype(np.uint64)
np_types = {
tl.float16: np.dtype(np.float16),
tl.float32: np.dtype(np.float32),
tl.float64: np.dtype(np.float64),
tl.int8: np.dtype(np.int8),
tl.uint8: np.dtype(np.uint8),
tl.int16: np.dtype(np.int16),
tl.uint16: np.dtype(np.uint16),
tl.int32: np.dtype(np.int32),
tl.uint32: np.dtype(np.uint32),
tl.int64: np.dtype(np.int64),
tl.uint64: np.dtype(np.uint64),
}
return np_types[tt_dtype]
# constants
def get_half_ty(self):
return tl.float16
def get_float_ty(self):
return tl.float32
def get_int64_ty(self):
return tl.int64
def get_ptr_ty(self, elt_ty, addr_space):
return tl.pointer_type(elt_ty, addr_space)
def get_block_ty(self, dtype, shape):
return tl.tensor(shape, dtype)
def get_int32(self, value):
return TensorHandle(np.array([value], dtype=np.int32), tl.int32)
def get_int64(self, value):
return TensorHandle(np.array([value], dtype=np.int64), tl.int64)
def get_fp16(self, value):
return TensorHandle(np.array([value], dtype=np.float16), tl.float16)
def get_fp32(self, value):
return TensorHandle(np.array([value], dtype=np.float32), tl.float32)
def get_null_value(self, type):
return TensorHandle(np.array([0], dtype=self.np_dtype(type)), type)
# programming model
def create_get_program_id(self, axis):
assert self.grid_idx is not None
return TensorHandle(np.array([self.grid_idx[axis]], dtype=np.int32), tl.int32)
def create_get_num_programs(self, axis):
return TensorHandle(np.array([self.grid_dim[axis]], dtype=np.int32), tl.int32)
# memory ops
def create_load(self, ptr, _0, _1, is_volatile):
mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1)
other = None
return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile)
def create_store(self, ptr, val, _0, _1):
mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1)
return self.create_masked_store(ptr, val, mask, None, None)
def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile):
dtype_tt = ptrs.dtype.element_ty
dtype_np = self.np_dtype(dtype_tt)
if other is None:
other = TensorHandle(np.ones_like(ptrs.data, dtype=dtype_np), dtype_tt)
ret = _interpreter.load(ptrs.data, mask.data, other.data, dtype_np)
return TensorHandle(ret, dtype_tt)
def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy):
return _interpreter.store(ptrs.data, value.data, mask.data)
# casting ops
def cast_impl(self, src, dst_type):
if isinstance(dst_type, tl.tensor):
dst_type = dst_type.dtype
return TensorHandle(src.data.astype(self.np_dtype(dst_type)), dst_type)
create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type)
create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type)
create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type)
create_fp_to_ui = lambda self, src, dst_type: self.cast_impl(src, dst_type)
create_fp_ext = lambda self, src, dst_type: self.cast_impl(src, dst_type)
create_fp_trunc = lambda self, src, dst_type: self.cast_impl(src, dst_type)
create_int_cast = lambda self, src, dst_type, is_signed: self.cast_impl(src, dst_type)
def create_fp_to_fp(self, src, dst_type):
assert "float8 not NotImplemented yet"
def create_bitcast(self, src, dst_type):
return TensorHandle(src.data.view(self.np_dtype(dst_type)), dst_type)
# binary operators
def binary_op(self, lhs, rhs, op):
return TensorHandle(op(lhs.data, rhs.data), lhs.dtype)
create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add)
create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide)
create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder)
create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract)
create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
create_sdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.floor_divide)
create_udiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.floor_divide)
create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder)
create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder)
create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add)
create_sub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract)
create_shl = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.left_shift)
create_lshr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift)
create_ashr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift)
create_minsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
create_minui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
create_minf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
create_maxsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
create_maxui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
create_maxf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
create_icmpSLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
create_icmpSLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
create_icmpSGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
create_icmpSGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
create_icmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
create_icmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
create_icmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
create_icmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
create_icmpEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
create_icmpNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
create_fcmpOLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
create_fcmpOGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
create_fcmpOLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
create_fcmpOGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
create_fcmpOEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
create_fcmpONE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
create_fcmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
create_fcmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
create_fcmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
create_fcmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
create_fcmpUEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
create_fcmpUNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and)
create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor)
create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or)
# ternary functions
def ternary_op(self, lhs, rhs, other, op):
return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype)
create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where)
# unary functions
def unary_op(self, arg, op):
return TensorHandle(op(arg.data), arg.dtype)
create_exp = lambda self, arg: self.unary_op(arg, np.exp)
create_cos = lambda self, arg: self.unary_op(arg, np.cos)
create_sin = lambda self, arg: self.unary_op(arg, np.sin)
create_log = lambda self, arg: self.unary_op(arg, np.log)
create_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt)
create_fabs = lambda self, arg: self.unary_op(arg, np.abs)
create_iabs = lambda self, arg: self.unary_op(arg, np.abs)
# tensor operators
create_dot = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.dot)
create_view = lambda self, arg, shape: TensorHandle(arg.data.reshape(shape), arg.dtype)
create_trans = lambda self, arg: self.unary_op(arg, np.transpose)
def create_dot(self, a, b, d, allow_tf32, maxNumImpreciseAcc):
return TensorHandle(np.dot(a.data, b.data) + d.data, a.dtype)
def create_make_range(self, start, stop):
return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32)
# pointer arithmetic
def create_addptr(self, ptr, offset):
dtype_tt = ptr.dtype.element_ty
return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype)
def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile):
ptrs, masks = ptr.materialize_pointers(boundary_check)
assert padding_option is None
other = None
return self.create_masked_load(ptrs, masks, other, cache_modifier, eviction_policy, is_volatile)
def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy):
ptrs, masks = ptr.materialize_pointers(boundary_check)
return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy)
def create_expand_dims(self, arg, axis):
return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype)
def create_broadcast(self, arg, shape):
return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype)
def create_int_to_ptr(self, val, dst_ty):
return TensorHandle(val.data.astype(np.uint64), dst_ty)
# def create_cat(self, lhs, rhs):
# pass
# def create_broadcast(self, arg, shape):
# pass
def create_splat(self, arg, shape):
return TensorHandle(np.full(shape, arg.data[0], dtype=self.np_dtype(arg.dtype)), arg.dtype)
# def create_atomic_cas(self, ptr, cmp, val, sem):
# pass
# def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem):
# pass
# def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure):
# pass
# def create_reduce(self, operands, axis):
# pass
# def create_reduce_ret(self, args):
# pass
# def create_scan(self, operands, axis):
# pass
# def create_scan_ret(self, args):
# pass
# def create_ptr_to_int(self, val, type):
# pass
# def create_int_to_ptr(self, val, type):
# pass
# def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack):
# pass
# def create_print(self, prefix, values):
# pass
# def create_assert(self, condition, message, fileName, funcName, lineNo):
# pass
# def create_undef(self, type):
# pass
# def create_barrier(self):
# pass
def create_make_block_ptr(self, base, shape, strides, offsets, tensor_shape, order):
    """Build a block-pointer handle (tl.make_block_ptr).

    Offsets are materialized into a numpy array so they can later be
    updated in place by `create_advance`.
    """
    mutable_offsets = np.array(offsets)
    return BlockPointerHandle(base, shape, strides, mutable_offsets, tensor_shape, order)
def create_advance(self, ptr, offsets):
    """Return a new block pointer advanced by `offsets` (one per dimension).

    Advancing must be a pure operation: the returned handle must not alias
    the input. The original implementation passed `ptr.offsets` straight
    into the new handle, so the in-place `+=` below silently mutated the
    source pointer as well; the offsets are now deep-copied first.
    """
    import copy
    assert len(ptr.offsets) == len(offsets)
    # Deep-copy so the "+=" below cannot leak back into `ptr`.
    ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides,
                             copy.deepcopy(ptr.offsets), ptr.tensor_shape, ptr.order)
    for i in range(len(offsets)):
        ret.offsets[i].data += offsets[i].data
    return ret
def patch_attr(obj, name, member, builder):
    """Rebind ``obj.name`` to a wrapper around ``member`` that always passes
    ``_builder=builder``, discarding any caller-supplied ``_builder``."""
    def wrapped(*args, **kwargs):
        kwargs.pop('_builder', None)
        return member(*args, _builder=builder, **kwargs)
    setattr(obj, name, wrapped)
def _patch_lang_tensor(tensor, builder):
    """Patch the tl.tensor class so its builtin methods run through the
    interpreter `builder`, and make instances behave like their numpy data."""
    for attr_name, attr in inspect.getmembers(tensor):
        if tl.core.is_builtin(attr):
            patch_attr(tensor, attr_name, attr, builder)
    # Make interpreted tensors usable where Python expects plain values.
    tensor.__index__ = lambda self: int(self.handle.data)
    tensor.__bool__ = lambda self: True  # always truthy, regardless of data
    tensor.__str__ = lambda self: str(self.handle.data)
    tensor.__getitem__ = lambda self, slices: self.handle.data[slices]
def _patch_lang_core(lang, builder):
    """Patch tl.core builtins to run through `builder` and replace
    `lang.reduce` with a numpy-backed implementation."""
    for attr_name, attr in inspect.getmembers(lang):
        if tl.core.is_builtin(attr):
            patch_attr(lang, attr_name, attr, builder)

    # reduce is better off with a separate patch due to how the builder
    # currently interfaces with custom (user-supplied) combine functions:
    # the numpy reduction is selected from the combine function's name.
    def _new_reduce(input, axis, combine_fn):
        numpy_reductions = {
            'maximum': np.max,
            '_sum_combine': np.sum,
        }
        reduced = numpy_reductions[combine_fn.fn.__name__](input.handle.data, axis=axis)
        ret_type = tl.block_type(input.dtype, reduced.shape)
        return tl.core.tensor(TensorHandle(reduced, input.dtype), ret_type)

    lang.reduce = _new_reduce
def _patch_lang_math(lang, builder):
    """Patch tl.math functions with numpy equivalents where one is known;
    everything else is replaced by a stub that raises NotImplementedError
    when called."""
    math = lang.math
    # triton math name -> numpy function name
    mapping = {
        'abs': 'abs',
        'acos': 'arccos',
        'asin': 'arcsin',
        'exp2': 'exp2',
        'log2': 'log2',
        'max': 'maximum',
    }

    def make_numpy(name):
        # Forward to numpy; the result reuses the first argument's
        # type/dtype (the original TODOs note this is not always correct).
        def impl(*args, **kwargs):
            ret_type = args[0].type  # TODO: incorrect
            ret_dtype = args[0].dtype  # TODO: incorrect
            np_args = [arg.handle.data for arg in args]
            np_kwargs = {k: v.handle.data for k, v in kwargs.items()}
            result = getattr(np, mapping[name])(*np_args, **np_kwargs)
            return tl.core.tensor(TensorHandle(result, ret_dtype), ret_type)
        return impl

    def make_fallback(name):
        def fallback(*args, **kwargs):
            raise NotImplementedError(f"""
{name} not supported in interpreter mode: no known numpy implementation.
If you think that {name} in fact does have a numpy implementation, please add it
to the mapping in python/triton/interpreter/new_interpreter.py:_patch_lang_math.
""")
        return fallback

    for fn_name, _ in inspect.getmembers(math):
        if fn_name in mapping:
            setattr(math, fn_name, make_numpy(fn_name))
        else:
            setattr(math, fn_name, make_fallback(fn_name))
# TODO: wrap everything in triton tensors
def _implicit_cvt(arg):
    """Wrap plain ints and data_ptr-bearing objects (torch tensors) as
    interpreter tl.tensor values; anything else passes through unchanged."""
    key_of = triton.runtime.jit.JITFunction._key_of
    type_of = triton.runtime.jit.JITFunction._type_of
    if isinstance(arg, int):
        # NOTE(review): stored as int32 — assumes the value fits in 32 bits;
        # confirm behavior for larger integers.
        ty = str_to_ty(type_of(key_of(arg)))
        return tl.tensor(TensorHandle(np.array([arg], dtype=np.int32), ty), ty)
    if hasattr(arg, 'data_ptr'):
        # Tensors are represented by their base address (uint64 pointer).
        ty = str_to_ty(type_of(key_of(arg)))
        return tl.tensor(TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty), ty)
    return arg
def _unwrap(tensor):
    """Return the underlying tensor of a triton.TensorWrapper, or `tensor`
    itself when it is not wrapped."""
    return tensor.base if isinstance(tensor, triton.TensorWrapper) else tensor
# Process-wide interpreter builder shared by all interpreted kernels; grid
# dimensions/indices are set on it right before each launch (GridExecutor).
builder = Builder()
# Launch-time keyword arguments consumed by the runtime rather than the
# kernel itself; they are stripped from kwargs before invoking the function.
RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization']
class GridExecutor:
    """Runs a jit'd function on the host, once per grid point, with the
    triton language patched to its numpy-based interpreter implementation.

    Device tensors are copied to the host before execution and copied back
    afterwards so in-kernel writes remain visible to the caller.
    """

    def __init__(self, fn, arg_names, grid):
        from .jit import _normalize_ty  # TODO: modularize
        self.fn = fn
        self.arg_names = arg_names
        self.grid = grid
        __annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
        # constexpr arguments bypass _implicit_cvt and are passed verbatim
        self.constexprs = [name for name in arg_names if __annotations__.get(name) == 'constexpr']

    def _patch_lang(self, builder):
        lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
        assert len(lang) == 1, "triton.language must be visible from within jit'd function"
        _patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
        _patch_lang_core(lang[0], builder)
        _patch_lang_math(lang[0], builder)

    def __call__(self, *args_dev, **kwargs):
        # we need to copy arguments to the host for the interpreter.
        # FIX: keyword tensor arguments are now copied to host too — the
        # original only handled positional ones, leaving kwarg tensors on
        # device and never propagating their side effects back.
        args_hst = [_unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev]
        # removes reserved keywords from kwargs
        kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}
        kwargs_hst = {k: _unwrap(v).cpu() if hasattr(v, 'data_ptr') else v for k, v in kwargs.items()}
        # remaps core language functions to interpreted ones
        self._patch_lang(builder)
        # implicitly convert tensor arguments to their base pointers
        args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst)
        args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()}
        # iterate through grid
        grid = self.grid(args) if callable(self.grid) else self.grid
        assert len(grid) <= 3
        grid = grid + (1,) * (3 - len(grid))
        builder.set_grid_dim(*grid)
        for x in range(grid[0]):
            for y in range(grid[1]):
                for z in range(grid[2]):
                    builder.set_grid_idx(x, y, z)
                    self.fn(**args)
        # copy arguments back to propagate side-effects
        for arg_dev, arg_hst in zip(args_dev, args_hst):
            if hasattr(arg_dev, 'data_ptr'):
                _unwrap(arg_dev).copy_(arg_hst.to(arg_dev.device))
        for name, v_dev in kwargs.items():
            if hasattr(v_dev, 'data_ptr'):
                _unwrap(v_dev).copy_(kwargs_hst[name].to(v_dev.device))
class InterpretedFunction:
    """Drop-in replacement for JITFunction that interprets the kernel on the
    host instead of compiling it."""

    def __init__(self, fn) -> None:
        self.fn = fn

        def run(*args, **kwargs):
            # `grid` must be supplied as a keyword argument when using .run
            launch_grid = kwargs['grid']
            cleaned = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ['grid']}
            return GridExecutor(self.fn, self.arg_names, launch_grid)(*args, **cleaned)

        self.run = run
        self.arg_names = [param.name for param in inspect.signature(fn).parameters.values()]

    def _patch_lang(self, builder):
        # NOTE(review): unlike GridExecutor._patch_lang, math is NOT patched
        # here — confirm whether that asymmetry is intentional.
        lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
        assert len(lang) == 1, "triton.language must be visible from within jit'd function"
        _patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
        _patch_lang_core(lang[0], builder)

    def __getitem__(self, grid):
        return GridExecutor(self.fn, self.arg_names, grid)

    def __call__(self, *args, **kwargs):
        self._patch_lang(builder)
        return self.fn(*args, **kwargs)

View File

@@ -14,6 +14,7 @@ from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union,
from .._C.libtriton.triton import TMAInfos
from ..common.backend import get_backend, path_to_ptxas
from ..language.core import dtype
from .interpreter import InterpretedFunction
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TRITON_VERSION = "2.1.0"
@@ -256,6 +257,8 @@ class JITFunction(KernelInterface[T]):
"float8_e5m2fnuz": "fp8e5b16",
"float8e4b15": "fp8e4b15",
"float8e4b15x4": "fp8e4b15x4",
"float8_e4m3fn": "fp8e4nv",
"float8_e5m2": "fp8e5",
"float16": "fp16",
"bfloat16": "bf16",
"float32": "fp32",
@@ -274,10 +277,6 @@ class JITFunction(KernelInterface[T]):
tys[v] = v
return key if isinstance(key, str) else f"*{tys[dtype_str]}"
def _make_signature(self, sig_key):
signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)])
return signature
def _make_constants(self, constexpr_key):
constants = dict(zip(self.constexprs, constexpr_key))
return constants
@@ -304,29 +303,29 @@ class JITFunction(KernelInterface[T]):
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={
"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
def _get_arg_specialization_key(self, arg) -> str:
arg_annotation = self.__annotations__.get(arg, '')
def _get_arg_specialization_key(self, arg_name, arg):
arg_annotation = self.__annotations__.get(arg_name, '')
if arg_annotation == '':
return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") \
else ({arg} % {JITFunction.divisibility} == 0, {arg} % {JITFunction.divisibility_8} == 0, {arg} == 1) if isinstance({arg}, int) \
else (False,)'
return (arg.data_ptr() % JITFunction.divisibility == 0) if hasattr(arg, "data_ptr") \
else (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1) if isinstance(arg, int) \
else (False,)
elif 'Tensor' in arg_annotation:
return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0)'
elif arg_annotation == 'int':
return f'({arg} % {JITFunction.divisibility} == 0, {arg} % {JITFunction.divisibility_8} == 0, {arg} == 1)'
return (arg.data_ptr() % JITFunction.divisibility == 0)
elif 'int' in arg_annotation or 'bool' in arg_annotation:
return (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1)
else:
return '(False,)'
return (False,)
def _get_arg_sig_key(self, arg) -> str:
arg_annotation = self.__annotations__.get(arg, '')
def _get_arg_sig_key(self, arg_name, arg) -> str:
arg_annotation = self.__annotations__.get(arg_name, '')
if 'Tensor' in arg_annotation:
return f'{arg}.dtype'
return arg.dtype
elif arg_annotation == 'bool':
return "i1"
elif arg_annotation == 'float':
return 'fp32'
else:
return f'_key_of({arg})'
return self._key_of(arg)
def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: List[bool]) -> str:
device_types = [device_type for device_type in device_types if device_type != '']
@@ -344,32 +343,110 @@ class JITFunction(KernelInterface[T]):
return device_types[0] if len(device_types) > 0 else 'cuda'
def _make_launcher(self):
regular_args = [f'{arg}' for i, arg in enumerate(
regular_args = [arg for i, arg in enumerate(
self.arg_names) if i not in self.constexprs]
constexpr_args = [
f'{arg}' for i, arg in enumerate(
self.arg_names) if i in self.constexprs]
args = ', '.join(regular_args)
# cache key for regular argument type
sig_keys = ', '.join([self._get_arg_sig_key(arg) for arg in regular_args])
device_types = '[' + ', '.join([f'_device_of({arg})' for arg in regular_args]) + ']'
pinned_memory_flags = '[' + ', '.join([f'_pinned_memory_of({arg})' for arg in regular_args]) + ']'
# cache key for constexpr argument values
constexpr_keys = ', '.join(constexpr_args)
# cache key for argument specialization
specializations = []
for i, arg in enumerate(regular_args):
if i in self.do_not_specialize:
continue
specializations += [self._get_arg_specialization_key(arg)]
constexpr_args = [arg for i, arg in enumerate(
self.arg_names) if i in self.constexprs]
spec_keys = ', '.join(specializations)
grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
def regular_args_v(args_proxy):
return [args_proxy[arg_name] for arg_name in regular_args]
def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, stream, warmup, device, device_type):
from ..compiler import (CompiledKernel, compile,
get_arch_default_num_stages,
get_arch_default_num_warps)
sig_key = tuple([self._get_arg_sig_key(arg_name, args_proxy[arg_name]) for arg_name in regular_args])
constexpr_key = tuple([args_proxy[arg_name] for arg_name in constexpr_args])
specializations = []
for i, arg_name in enumerate(regular_args):
if i in self.do_not_specialize:
continue
specializations += [self._get_arg_specialization_key(arg_name, args_proxy[arg_name])]
spec_key = tuple(specializations)
assert num_ctas > 0
assert grid is not None
if callable(grid):
grid = grid(args_proxy)
grid_size = len(grid)
grid_0 = grid[0]
grid_1 = grid[1] if grid_size > 1 else 1
grid_2 = grid[2] if grid_size > 2 else 1
if device_type is None:
device_types = [self._device_of(arg) for arg in regular_args_v(args_proxy)]
device_types = [_device_type for _device_type in device_types if _device_type != '']
device_type = self._conclude_device_type(device_types, [self._pinned_memory_of(arg) for arg in
regular_args_v(args_proxy)])
device_backend = None
if device_type not in ['cuda']:
device_backend = get_backend(device_type)
if device_backend is None:
raise ValueError('Cannot find backend for ' + device_type)
if device is None:
if device_type in ['cuda']:
device = get_current_device()
set_current_device(device)
else:
device = device_backend.get_current_device()
device_backend.set_current_device(device)
if stream is None and not warmup:
if device_type in ['cuda']:
stream = get_cuda_stream(device)
else:
stream = device_backend.get_stream()
if num_warps is None:
num_warps = get_arch_default_num_warps(device_type)
if num_stages is None:
num_stages = get_arch_default_num_stages(device_type)
key = (version_key(), sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, self.debug)
if extern_libs is not None:
key = (key, tuple(extern_libs.items()))
bin = self.cache[device].get(key, None)
if bin is not None:
# build dict of constant values
args = regular_args_v(args_proxy)
# Create tensormaps and append to args
args = bin.assemble_tensormap_to_arg(args)
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
return bin
# kernel not cached -- compile
else:
# build dict of constant values
args = regular_args_v(args_proxy)
all_args = tuple([args_proxy[arg_name] for arg_name in self.arg_names])
configs = self._get_config(*all_args),
constants = self._make_constants(constexpr_key)
constants.update({i: None for i, arg in enumerate(all_args) if arg is None})
constants.update({i: 1 for i in configs[0].equal_to_1})
# build kernel signature -- doesn't include specialized arguments
signature = {i: self._type_of(self._key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs}
# build stub signature -- includes arguments that are specialized
for i, arg in constants.items():
if callable(arg):
raise TypeError(f"Callable constexpr at index {i} is not supported")
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
# Create tensormaps and append to args
args = bin.assemble_tensormap_to_arg(args)
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
self.cache[device][key] = bin
return bin
return None
# create a wrapper to call launcher_body
args_map = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = triton.language.dtype(\'{dflt}\')' if dtype.is_dtype(f'{dflt}') else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
src = f"""
import triton
<<<<<<< HEAD
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, matrix_instr_nonkdim=0, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages
sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()}
@@ -449,19 +526,12 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
self.cache[device][key] = bin
return bin
return None
=======
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, stream, warmup, device, device_type)
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
"""
scope = {"version_key": version_key(),
"get_cuda_stream": get_cuda_stream,
"self": self,
"_spec_of": self._spec_of,
"_key_of": self._key_of,
"_device_of": self._device_of,
"_pinned_memory_of": self._pinned_memory_of,
"cache": self.cache,
"__spec__": __spec__,
"get_backend": get_backend,
"get_current_device": get_current_device,
"set_current_device": set_current_device}
scope = {"launcher_body": launcher_body}
exec(src, scope)
return scope[self.fn.__name__]
@@ -572,7 +642,6 @@ def jit(
do_not_specialize: Optional[Iterable[int]] = None,
debug: Optional[bool] = None,
noinline: Optional[bool] = None,
interpret: Optional[bool] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
"""
Decorator for JIT-compiling a function using the Triton compiler.
@@ -594,9 +663,8 @@ def jit(
def decorator(fn: T) -> JITFunction[T]:
assert callable(fn)
if interpret:
from ..interpreter.interpreter import GridSelector
return GridSelector(fn)
if os.getenv("TRITON_INTERPRET", "0") == "1":
return InterpretedFunction(fn)
else:
return JITFunction(
fn,

View File

@@ -32,8 +32,11 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None):
"""
if torch.cuda.current_stream() == torch.cuda.default_stream():
raise RuntimeError("Cannot capture graph in default stream. Please use side stream in benchmark code.")
# record CUDAGraph
# warmup
fn()
# step 1 - we estimate the amount of time the kernel call takes
# NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point
# but it is probably good enough
if grad_to_none is not None:
for x in grad_to_none:
x.detach_()
@@ -43,39 +46,35 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None):
with torch.cuda.graph(g):
fn()
torch.cuda.synchronize()
fn = lambda: g.replay()
# Estimate the runtime of the function
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
fn()
g.replay()
end_event.record()
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event)
# compute number of repetition to last `rep` ms
n_repeat = max(1, int(rep / estimate_ms))
# compute number of repetition to last `rep` ms
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
ret = []
n_retries = 50
for _ in range(n_retries):
# Benchmark
torch.cuda.synchronize()
# step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize
# host overhead
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
for i in range(n_repeat):
# we don't want `fn` to accumulate gradient values
# if it contains a backward pass. So we clear the
# provided gradients
if grad_to_none is not None:
for x in grad_to_none:
x.grad = None
# record time of `fn`
start_event[i].record()
fn()
end_event[i].record()
torch.cuda.synchronize()
# measure time and return
ret = []
n_retries = 10
for i in range(n_retries):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
g.replay()
end_event.record()
torch.cuda.synchronize()
times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
ret.append(torch.min(times))
ret += [start_event.elapsed_time(end_event) / n_repeat]
return torch.mean(torch.tensor(ret)).item()
@@ -266,7 +265,7 @@ class Mark:
self.fn = fn
self.benchmarks = benchmarks
def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool):
def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, **kwrags):
import os
import matplotlib.pyplot as plt
@@ -287,7 +286,7 @@ class Mark:
row_mean, row_min, row_max = [], [], []
for y in bench.line_vals:
ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args)
ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags)
try:
y_mean, y_min, y_max = ret
except TypeError:
@@ -328,14 +327,14 @@ class Mark:
if save_path:
df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)
def run(self, show_plots=False, print_data=False, save_path=''):
def run(self, show_plots=False, print_data=False, save_path='', **kwargs):
has_single_bench = isinstance(self.benchmarks, Benchmark)
benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
if save_path:
html = open(os.path.join(save_path, "results.html"), "w")
html.write("<html><body>\n")
for bench in benchmarks:
self._run(bench, save_path, show_plots, print_data)
self._run(bench, save_path, show_plots, print_data, **kwargs)
if save_path:
html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
if save_path:
@@ -368,7 +367,7 @@ def get_dram_gbps(backend=None, device=None):
return bw_gbps
def get_max_tensorcore_tflops(dtype, backend=None, device=None, clock_rate=None):
def get_max_tensorcore_tflops(dtype, clock_rate, backend=None, device=None):
import torch
from .runtime import driver
@@ -378,8 +377,6 @@ def get_max_tensorcore_tflops(dtype, backend=None, device=None, clock_rate=None)
device = torch.cuda.current_device()
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4
if not clock_rate:
clock_rate = driver.utils.get_device_properties(device)["sm_clock_rate"] # in kHz
capability = torch.cuda.get_device_capability(device)
if capability[0] < 8:
assert dtype == torch.float16
@@ -423,21 +420,6 @@ def cuda_memcheck(**target_kwargs):
return decorator
def nvsmi_attr(attrs):
attrs = ",".join(attrs)
cmd = [
"nvidia-smi",
"-i",
"0",
"--query-gpu=" + attrs,
"--format=csv,noheader,nounits",
]
out = subprocess.check_output(cmd)
ret = out.decode(sys.stdout.encoding).split(",")
ret = [int(x) for x in ret]
return ret
@contextmanager
def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
try:
@@ -458,8 +440,8 @@ def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
]
)
cur_sm_clock = nvsmi_attr(["clocks.current.sm"])[0]
cur_mem_clock = nvsmi_attr(["clocks.current.memory"])[0]
cur_sm_clock = nvsmi(["clocks.current.sm"])[0]
cur_mem_clock = nvsmi(["clocks.current.memory"])[0]
assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz"
assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz"
tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock
@@ -471,7 +453,7 @@ def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"])
def get_max_simd_tflops(dtype, backend=None, device=None):
def get_max_simd_tflops(dtype, clock_rate, backend=None, device=None):
import torch
from .runtime import driver
@@ -481,7 +463,6 @@ def get_max_simd_tflops(dtype, backend=None, device=None):
device = torch.cuda.current_device()
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4
clock_rate = driver.utils.get_device_properties(device)["sm_clock_rate"] # in kHz
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
if dtype == torch.float32:

View File

@@ -20,7 +20,7 @@ data along with utilities to load, unload and launch the kernel.
signature is provided as a list of (optionally divisibility-hinted) types
or constexpr values, e.g.
`compile.py --kernel-name kernel --signature "*f32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py`
`compile.py --kernel-name kernel --signature "*fp32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py`
will compile triton.JITFunction of name `kernel` inside the file `/path/to/kernel.py`.
Said kernel will be specialized such that argument 0, 1 are assumed to be multiple of 16,
@@ -51,7 +51,7 @@ if __name__ == "__main__":
args = parser.parse_args()
out_name = args.out_name if args.out_name else args.kernel_name
out_path = args.out_path if args.out_path else out_name
out_path = args.out_path if args.out_path else Path(out_name)
# execute python sources and extract functions wrapped in JITFunction
arg_path = Path(args.path)

View File

@@ -20,8 +20,13 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import functools
import os
import re
import subprocess
import tempfile
from ..common.backend import path_to_cuobjdump, path_to_nvdisasm
FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*')
SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*')
@@ -60,11 +65,26 @@ def processSassLines(fline, sline, labels):
return (f'{ctrl}', f'{asm}')
@functools.lru_cache()
def get_sass(cubin_asm, fun=None):
fd, path = tempfile.mkstemp()
try:
with open(fd, 'wb') as cubin:
cubin.write(cubin_asm)
sass = extract(path, fun)
finally:
os.remove(path)
return sass
def extract(file_path, fun):
cuobjdump, _ = path_to_cuobjdump()
nvdisasm, _ = path_to_nvdisasm()
os.environ["NVDISASM_PATH"] = nvdisasm
if fun is None:
sass_str = subprocess.check_output(["cuobjdump", "-sass", file_path])
sass_str = subprocess.check_output([cuobjdump, "-sass", file_path])
else:
sass_str = subprocess.check_output(["cuobjdump", "-fun", fun, "-sass", file_path])
sass_str = subprocess.check_output([cuobjdump, "-fun", fun, "-sass", file_path])
sass_lines = sass_str.splitlines()
line_idx = 0
while line_idx < len(sass_lines):