mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge commit 'ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33' into ifu-rebase-again
Conflicts: .gitignore .gitmodules README.md bin/triton-translate.cpp include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td include/triton/Target/AMDGCN/AMDGCNTranslation.h include/triton/Target/HSACO/HSACOTranslation.h lib/Analysis/Allocation.cpp lib/Analysis/Utility.cpp lib/Conversion/TritonGPUToLLVM/CMakeLists.txt lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/Utility.cpp lib/Conversion/TritonGPUToLLVM/Utility.h lib/Dialect/TritonGPU/IR/Dialect.cpp lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp lib/Target/HSACO/CMakeLists.txt lib/Target/HSACO/HSACOTranslation.cpp lib/Target/LLVMIR/LLVMIRTranslation.cpp python/src/triton.cc python/test/unit/language/test_core.py python/test/unit/operators/test_flash_attention.py python/triton/compiler/compiler.py python/triton/compiler/make_launcher.py python/triton/language/semantic.py python/triton/runtime/jit.py python/tutorials/06-fused-attention.py python/tutorials/11-grouped-gemm.py test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
@@ -101,20 +101,34 @@ def get_backend(device_type: str):
|
||||
return _backends[device_type] if device_type in _backends else None
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def path_to_ptxas():
|
||||
def _path_to_binary(binary: str):
|
||||
base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
|
||||
paths = [
|
||||
os.environ.get("TRITON_PTXAS_PATH", ""),
|
||||
os.path.join(base_dir, "third_party", "cuda", "bin", "ptxas")
|
||||
os.path.join(base_dir, "third_party", "cuda", "bin", binary)
|
||||
]
|
||||
|
||||
for ptxas in paths:
|
||||
ptxas_bin = ptxas.split(" ")[0]
|
||||
if os.path.exists(ptxas_bin) and os.path.isfile(ptxas_bin):
|
||||
result = subprocess.check_output([ptxas_bin, "--version"], stderr=subprocess.STDOUT)
|
||||
for p in paths:
|
||||
bin = p.split(" ")[0]
|
||||
if os.path.exists(bin) and os.path.isfile(bin):
|
||||
result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT)
|
||||
if result is not None:
|
||||
version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
|
||||
if version is not None:
|
||||
return ptxas, version.group(1)
|
||||
raise RuntimeError("Cannot find ptxas")
|
||||
return p, version.group(1)
|
||||
raise RuntimeError(f"Cannot find {binary}")
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def path_to_ptxas():
|
||||
return _path_to_binary("ptxas")
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def path_to_cuobjdump():
|
||||
return _path_to_binary("cuobjdump")
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def path_to_nvdisasm():
|
||||
return _path_to_binary("nvdisasm")
|
||||
|
||||
@@ -199,7 +199,7 @@ class ContainsReturnChecker(ast.NodeVisitor):
|
||||
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, function_name, arch,
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, function_name, target,
|
||||
module=None, is_kernel=False, function_types: Optional[Dict] = None,
|
||||
debug=False, noinline=False, file_name: Optional[str] = None, begin_line=0):
|
||||
self.context = context
|
||||
@@ -208,7 +208,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# node.lineno starts from 1, so we need to subtract 1
|
||||
self.begin_line = begin_line - 1
|
||||
self.builder.set_loc(file_name, begin_line, 0)
|
||||
self.builder.arch = arch
|
||||
self.builder.target = target
|
||||
self.module = self.builder.create_module() if module is None else module
|
||||
self.function_ret_types = {} if function_types is None else function_types
|
||||
self.prototype = prototype
|
||||
@@ -912,7 +912,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
file_name, begin_line = _get_fn_file_line(fn)
|
||||
generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module,
|
||||
function_name=fn_name, function_types=self.function_ret_types, debug=debug, noinline=fn.noinline,
|
||||
file_name=file_name, begin_line=begin_line, arch=self.builder.arch)
|
||||
file_name=file_name, begin_line=begin_line, target=self.builder.target)
|
||||
generator.visit(fn.parse())
|
||||
callee_ret_type = generator.last_ret_type
|
||||
self.function_ret_types[fn_name] = callee_ret_type
|
||||
@@ -1108,7 +1108,7 @@ def kernel_suffix(signature, specialization):
|
||||
return suffix
|
||||
|
||||
|
||||
def ast_to_ttir(fn, signature, specialization, constants, debug, arch):
|
||||
def ast_to_ttir(fn, signature, specialization, constants, debug, target):
|
||||
# canonicalize signature
|
||||
if isinstance(signature, str):
|
||||
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
|
||||
@@ -1137,7 +1137,7 @@ def ast_to_ttir(fn, signature, specialization, constants, debug, arch):
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants,
|
||||
function_name=function_name, attributes=new_attrs,
|
||||
is_kernel=True, debug=debug, file_name=file_name, begin_line=begin_line,
|
||||
arch=arch)
|
||||
target=target)
|
||||
try:
|
||||
generator.visit(fn.parse())
|
||||
except CompilationError as e:
|
||||
|
||||
@@ -5,10 +5,18 @@ import hashlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
<<<<<<< HEAD
|
||||
import tempfile
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
=======
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from dataclasses import dataclass
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
|
||||
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs,
|
||||
compile_ptx_to_cubin, get_env_vars, get_num_warps,
|
||||
@@ -20,11 +28,11 @@ from ..common.build import is_hip
|
||||
# from ..runtime import driver, jit, JITFunction
|
||||
# TODO: runtime.errors
|
||||
from ..runtime.autotuner import OutOfResources
|
||||
from ..runtime.cache import get_cache_manager
|
||||
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
|
||||
from ..runtime.driver import driver
|
||||
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device,
|
||||
get_device_capability, version_key)
|
||||
from ..tools.disasm import extract
|
||||
from ..tools.disasm import get_sass
|
||||
from .code_generator import ast_to_ttir
|
||||
from .make_launcher import make_stub
|
||||
from .utils import (InfoFromBackendForTensorMap, TensorMapManager,
|
||||
@@ -32,6 +40,24 @@ from .utils import (InfoFromBackendForTensorMap, TensorMapManager,
|
||||
|
||||
CUDA_DEFAULT_WARP_SIZE = 32
|
||||
|
||||
@dataclass
|
||||
class CudaTargetDescriptor:
|
||||
capability: int
|
||||
num_warps: int
|
||||
|
||||
|
||||
def _is_cuda(target):
|
||||
return isinstance(target, CudaTargetDescriptor)
|
||||
|
||||
|
||||
class LazyDict(dict):
|
||||
def __getitem__(self, key):
|
||||
val = dict.__getitem__(self, key)
|
||||
if callable(val):
|
||||
return val()
|
||||
return val
|
||||
|
||||
|
||||
def inline_triton_ir(mod):
|
||||
pm = ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
@@ -40,11 +66,12 @@ def inline_triton_ir(mod):
|
||||
return mod
|
||||
|
||||
|
||||
def ttir_compute_capability_rewrite(mod, arch):
|
||||
def ttir_compute_capability_rewrite(mod, target):
|
||||
# For hardware without support, we must rewrite all load/store
|
||||
# with block (tensor) pointers into tensors of pointers
|
||||
pm = ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
<<<<<<< HEAD
|
||||
if _is_cuda(arch):
|
||||
pm.add_rewrite_tensor_pointer_pass(arch, False)
|
||||
elif is_hip():
|
||||
@@ -52,13 +79,17 @@ def ttir_compute_capability_rewrite(mod, arch):
|
||||
pm.add_rewrite_tensor_pointer_pass(capability, True)
|
||||
else:
|
||||
assert(False, "unsupported target")
|
||||
=======
|
||||
if _is_cuda(target):
|
||||
pm.add_rewrite_tensor_pointer_pass(target.capability)
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def optimize_ttir(mod, arch):
|
||||
def optimize_ttir(mod, target):
|
||||
mod = inline_triton_ir(mod)
|
||||
mod = ttir_compute_capability_rewrite(mod, arch)
|
||||
mod = ttir_compute_capability_rewrite(mod, target)
|
||||
pm = ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
pm.add_inliner_pass()
|
||||
@@ -72,6 +103,7 @@ def optimize_ttir(mod, arch):
|
||||
return mod
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, arch):
|
||||
pm = ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
@@ -79,21 +111,36 @@ def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, arch):
|
||||
pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize, num_ctas, 0)
|
||||
else:
|
||||
pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize, num_ctas, arch)
|
||||
=======
|
||||
def ttir_to_ttgir(mod, num_warps, num_ctas, target):
|
||||
pm = ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
pm.add_convert_triton_to_tritongpu_pass(num_warps, 32, num_ctas, target.capability)
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch,
|
||||
cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_inst_type):
|
||||
=======
|
||||
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target,
|
||||
cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue):
|
||||
is_cuda = _is_cuda(target)
|
||||
if is_cuda:
|
||||
capability = target.capability
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
pm = ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
pm.add_tritongpu_coalesce_pass()
|
||||
# TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
|
||||
pm.add_plan_cta_pass(cluster_info)
|
||||
if _is_cuda(arch):
|
||||
pm.add_tritongpu_rewrite_tensor_pointer_pass(arch)
|
||||
if is_cuda:
|
||||
pm.add_tritongpu_rewrite_tensor_pointer_pass(capability)
|
||||
pm.add_plan_cta_pass(cluster_info)
|
||||
pm.add_tritongpu_remove_layout_conversions_pass()
|
||||
<<<<<<< HEAD
|
||||
if _is_cuda(arch):
|
||||
pm.add_tritongpu_accelerate_matmul_pass(arch)
|
||||
# TODO change interface of accelerate_matmul_pass
|
||||
@@ -101,6 +148,10 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch,
|
||||
matrix_core_version = gpu_matrix_core_version()
|
||||
matrix_inst_size = matrix_inst_type
|
||||
pm.add_tritonamdgpu_accelerate_matmul_pass(matrix_core_version, matrix_inst_size)
|
||||
=======
|
||||
if is_cuda:
|
||||
pm.add_tritongpu_accelerate_matmul_pass(capability)
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
pm.add_tritongpu_remove_layout_conversions_pass()
|
||||
if optimize_epilogue:
|
||||
pm.add_tritongpu_optimize_epilogue_pass()
|
||||
@@ -114,20 +165,25 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch,
|
||||
# it's the responsibility of the compiler to figure out the exact
|
||||
# `num_warps` to use.
|
||||
# TODO: support the case where `num_warps` from user is not 4.
|
||||
<<<<<<< HEAD
|
||||
if _is_cuda(arch) and arch // 10 >= 9 and enable_warp_specialization and num_warps == 4:
|
||||
pm.add_tritongpu_ws_feasibility_checking_pass(arch)
|
||||
=======
|
||||
if capability // 10 >= 9 and enable_warp_specialization and num_warps == 4:
|
||||
pm.add_tritongpu_ws_feasibility_checking_pass(capability)
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
pm.run(mod)
|
||||
ws_enabled = ir.is_ws_supported(mod)
|
||||
pm = ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
if ws_enabled:
|
||||
pm.add_tritongpu_wsdecomposing_pass(arch)
|
||||
pm.add_tritongpu_wspipeline_pass(
|
||||
num_stages, num_warps, arch)
|
||||
pm.add_tritongpu_wsmutex_pass(arch)
|
||||
pm.add_tritongpu_wsmaterialization_pass(arch)
|
||||
pm.add_tritongpu_wsdecomposing_pass(capability)
|
||||
pm.add_tritongpu_wspipeline_pass(num_stages, num_warps, capability)
|
||||
pm.add_tritongpu_wsmutex_pass(capability)
|
||||
pm.add_tritongpu_wsmaterialization_pass(capability)
|
||||
pm.add_cse_pass()
|
||||
else:
|
||||
<<<<<<< HEAD
|
||||
if is_hip():
|
||||
pm.add_tritongpu_pipeline_pass(
|
||||
num_stages, num_warps, num_ctas, 0)
|
||||
@@ -139,6 +195,11 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch,
|
||||
else:
|
||||
pm.add_tritongpu_materialize_load_store_pass(num_warps, arch)
|
||||
if _is_cuda(arch) and arch // 10 <= 8:
|
||||
=======
|
||||
pm.add_tritongpu_pipeline_pass(num_stages, num_warps, num_ctas, capability)
|
||||
pm.add_tritongpu_materialize_load_store_pass(num_warps, capability)
|
||||
if capability // 10 <= 8:
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
pm.add_tritongpu_prefetch_pass()
|
||||
pm.add_tritongpu_optimize_dot_operands_pass()
|
||||
pm.add_tritongpu_remove_layout_conversions_pass()
|
||||
@@ -148,7 +209,11 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch,
|
||||
pm.add_tritongpu_reorder_instructions_pass()
|
||||
pm.add_cse_pass()
|
||||
pm.add_symbol_dce_pass()
|
||||
<<<<<<< HEAD
|
||||
if _is_cuda(arch) and arch // 10 >= 9:
|
||||
=======
|
||||
if capability // 10 >= 9:
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
pm.add_tritongpu_fence_insertion_pass()
|
||||
pm.add_tritongpu_ws_fixup_missing_attrs_pass()
|
||||
pm.run(mod)
|
||||
@@ -162,12 +227,21 @@ def _add_external_libs(mod, libs):
|
||||
add_external_libs(mod, list(libs.keys()), list(libs.values()))
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
def ttgir_to_llir(mod, extern_libs, arch, tma_infos, waves_per_eu=0):
|
||||
if extern_libs:
|
||||
_add_external_libs(mod, extern_libs)
|
||||
# TODO: separate tritongpu_to_llvmir for different backends
|
||||
if _is_cuda(arch):
|
||||
return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, runtime.TARGET.NVVM, waves_per_eu)
|
||||
=======
|
||||
def ttgir_to_llir(mod, extern_libs, target, tma_infos):
|
||||
if extern_libs:
|
||||
_add_external_libs(mod, extern_libs)
|
||||
# TODO: separate tritongpu_to_llvmir for different backends
|
||||
if _is_cuda(target):
|
||||
return translate_triton_gpu_to_llvmir(mod, target.capability, tma_infos, runtime.TARGET.NVVM)
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
else:
|
||||
return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL, waves_per_eu)
|
||||
|
||||
@@ -190,7 +264,7 @@ def ptx_get_version(cuda_version) -> int:
|
||||
raise RuntimeError("Triton only support CUDA 10.0 or higher")
|
||||
|
||||
|
||||
def llir_to_ptx(mod: Any, arch: int, ptx_version: int = None) -> str:
|
||||
def llir_to_ptx(mod: Any, target: CudaTargetDescriptor, ptx_version: int = None) -> str:
|
||||
'''
|
||||
Translate TritonGPU module to PTX code.
|
||||
:param mod: a TritonGPU dialect module
|
||||
@@ -199,10 +273,10 @@ def llir_to_ptx(mod: Any, arch: int, ptx_version: int = None) -> str:
|
||||
if ptx_version is None:
|
||||
_, cuda_version = path_to_ptxas()
|
||||
ptx_version = ptx_get_version(cuda_version)
|
||||
return translate_llvmir_to_ptx(mod, arch, ptx_version)
|
||||
return translate_llvmir_to_ptx(mod, target.capability, ptx_version)
|
||||
|
||||
|
||||
def ptx_to_cubin(ptx: str, arch: int):
|
||||
def ptx_to_cubin(ptx: str, target: CudaTargetDescriptor):
|
||||
'''
|
||||
Compile TritonGPU module to cubin.
|
||||
:param ptx: ptx code
|
||||
@@ -210,7 +284,11 @@ def ptx_to_cubin(ptx: str, arch: int):
|
||||
:return: str
|
||||
'''
|
||||
ptxas, _ = path_to_ptxas()
|
||||
<<<<<<< HEAD
|
||||
return compile_ptx_to_cubin(ptx, ptxas, arch)
|
||||
=======
|
||||
return compile_ptx_to_cubin(ptx, ptxas, target.capability)
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
@@ -230,13 +308,15 @@ def get_kernel_name(src: str, pattern: str) -> str:
|
||||
|
||||
|
||||
def convert_type_repr(x):
|
||||
match = re.search(r'!tt\.ptr<(.*)>', x)
|
||||
# Currently we only capture the pointer type and assume the pointer is on global memory.
|
||||
# TODO: Capture and support shared memory space
|
||||
match = re.search(r'!tt\.ptr<([^,]+)', x)
|
||||
if match is not None:
|
||||
return '*' + convert_type_repr(match.group(1))
|
||||
return x
|
||||
|
||||
|
||||
def make_hash(fn, arch, env_vars, **kwargs):
|
||||
def make_hash(fn, target, env_vars, **kwargs):
|
||||
if isinstance(fn, JITFunction):
|
||||
configs = kwargs["configs"]
|
||||
signature = kwargs["signature"]
|
||||
@@ -253,9 +333,16 @@ def make_hash(fn, arch, env_vars, **kwargs):
|
||||
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
|
||||
configs_key = [get_conf_key(conf) for conf in configs]
|
||||
env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())]
|
||||
<<<<<<< HEAD
|
||||
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{arch}-{env_vars_list}"
|
||||
=======
|
||||
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
return hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||
assert isinstance(fn, str)
|
||||
ignore_version = kwargs.get('ignore_version', False)
|
||||
if (ignore_version):
|
||||
return hashlib.md5((Path(fn).read_text()).encode("utf-8")).hexdigest()
|
||||
return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
@@ -266,7 +353,8 @@ def make_hash(fn, arch, env_vars, **kwargs):
|
||||
# (letters, digits, or underscores), and capture it as group 1 (the function name)
|
||||
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
|
||||
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
|
||||
mlir_prototype_pattern = r'^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
|
||||
# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3
|
||||
mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$"
|
||||
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
|
||||
prototype_pattern = {
|
||||
"ttir": mlir_prototype_pattern,
|
||||
@@ -274,7 +362,11 @@ prototype_pattern = {
|
||||
"ptx": ptx_prototype_pattern,
|
||||
}
|
||||
|
||||
mlir_arg_type_pattern = r'%\w+: ([^,^\)\s]+)(?: \{\S+ = \S+ : \S+\})?,?'
|
||||
# - ((?:[^,\s<]+|<[^>]+>)+): Capturing group that matches one or more of either:
|
||||
# [^,\s<]+: One or more characters that are not a comma, whitespace, or the < symbol.
|
||||
# |: OR
|
||||
# <[^>]+>: A string that starts with < and ends with >, containing any characters except > in between.
|
||||
mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<]+|<[^>]+>)+),?'
|
||||
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
|
||||
arg_type_pattern = {
|
||||
"ttir": mlir_arg_type_pattern,
|
||||
@@ -282,7 +374,11 @@ arg_type_pattern = {
|
||||
"ptx": ptx_arg_type_pattern,
|
||||
}
|
||||
if is_hip():
|
||||
<<<<<<< HEAD
|
||||
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
|
||||
=======
|
||||
ttgir_num_warps_pattern = r'"triton_gpu_rocm.num-warps"\s?=\s?(\d+)\s?:'
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
else:
|
||||
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
|
||||
|
||||
@@ -311,6 +407,7 @@ def parse_mlir_module(path, context):
|
||||
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], defaults=[set(), set(), set(), set()])
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
# TODO: architecture descriptor class
|
||||
def _is_cuda(arch):
|
||||
return isinstance(arch, int)
|
||||
@@ -337,6 +434,15 @@ def get_architecture_descriptor(capability):
|
||||
capability = get_device_capability(device)
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
return capability
|
||||
=======
|
||||
def get_cuda_capability(capability):
|
||||
if capability is None:
|
||||
device = get_current_device()
|
||||
capability = get_device_capability(device)
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
return capability
|
||||
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
|
||||
@functools.lru_cache
|
||||
def get_arch_default_num_warps(device_type):
|
||||
@@ -347,15 +453,19 @@ def get_arch_default_num_warps(device_type):
|
||||
assert _device_backend
|
||||
arch = _device_backend.get_architecture_descriptor()
|
||||
num_warps = arch["num_warps"]
|
||||
|
||||
return num_warps
|
||||
|
||||
@functools.lru_cache
|
||||
def get_arch_default_num_stages(device_type, capability=None):
|
||||
<<<<<<< HEAD
|
||||
if device_type in ["cuda"]:
|
||||
arch = get_architecture_descriptor(capability)
|
||||
is_cuda = device_type == "cuda" and _is_cuda(arch)
|
||||
num_stages = 3 if is_cuda and arch >= 75 else 2
|
||||
=======
|
||||
if device_type == "cuda":
|
||||
num_stages = 3 if get_cuda_capability(capability) >= 75 else 2
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
else:
|
||||
_device_backend = get_backend(device_type)
|
||||
assert _device_backend
|
||||
@@ -365,11 +475,16 @@ def get_arch_default_num_stages(device_type, capability=None):
|
||||
return num_stages
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
def add_cuda_stages(arch, extern_libs, stages):
|
||||
=======
|
||||
def add_cuda_stages(target, extern_libs, stages):
|
||||
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
stages["ptx"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: llir_to_ptx(src, arch))
|
||||
lambda src: llir_to_ptx(src, target))
|
||||
stages["cubin"] = (lambda path: Path(path).read_bytes(),
|
||||
lambda src: ptx_to_cubin(src, arch))
|
||||
lambda src: ptx_to_cubin(src, target))
|
||||
|
||||
|
||||
def compile(fn, **kwargs):
|
||||
@@ -379,6 +494,7 @@ def compile(fn, **kwargs):
|
||||
|
||||
if is_hip():
|
||||
device_type = "hip"
|
||||
<<<<<<< HEAD
|
||||
capability = None
|
||||
|
||||
if device_type == "cuda":
|
||||
@@ -393,6 +509,12 @@ def compile(fn, **kwargs):
|
||||
if is_hip():
|
||||
is_cuda = False
|
||||
warp_size = CUDA_DEFAULT_WARP_SIZE if _is_cuda(arch) else arch["warp_size"]
|
||||
=======
|
||||
is_cuda = device_type == "cuda"
|
||||
if is_hip():
|
||||
is_cuda = False
|
||||
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
context = ir.context()
|
||||
constants = kwargs.get("constants", dict())
|
||||
num_warps = kwargs.get("num_warps", get_arch_default_num_warps(device_type))
|
||||
@@ -420,11 +542,23 @@ def compile(fn, **kwargs):
|
||||
cluster_info.clusterDimY = kwargs["clusterDims"][1]
|
||||
cluster_info.clusterDimZ = kwargs["clusterDims"][2]
|
||||
tma_infos = TMAInfos()
|
||||
<<<<<<< HEAD
|
||||
|
||||
=======
|
||||
# build architecture descriptor
|
||||
if device_type == "cuda":
|
||||
_device_backend = get_backend(device_type)
|
||||
target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps)
|
||||
else:
|
||||
_device_backend = get_backend(device_type)
|
||||
assert _device_backend
|
||||
target = _device_backend.get_architecture_descriptor(**kwargs)
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
# build compilation stages
|
||||
stages = dict()
|
||||
stages["ast"] = (lambda path: fn, None)
|
||||
stages["ttir"] = (lambda path: parse_mlir_module(path, context),
|
||||
<<<<<<< HEAD
|
||||
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
|
||||
if is_cuda:
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
|
||||
@@ -456,12 +590,23 @@ def compile(fn, **kwargs):
|
||||
stages["llir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos))
|
||||
_device_backend.add_stages(arch, extern_libs, stages)
|
||||
=======
|
||||
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
|
||||
if is_cuda:
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
|
||||
stages["llir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, target, tma_infos))
|
||||
add_cuda_stages(target, extern_libs, stages)
|
||||
elif device_type == "hip":
|
||||
_device_backend.add_stages(target, extern_libs, stages, num_warps=num_warps, num_stages=num_stages)
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
else:
|
||||
# pass the user's configuration to the backend device.
|
||||
arch["num_warps"] = num_warps
|
||||
arch["num_stages"] = num_stages
|
||||
arch["num_ctas"] = num_ctas
|
||||
_device_backend.add_stages(arch, extern_libs, stages)
|
||||
target["num_warps"] = num_warps
|
||||
target["num_stages"] = num_stages
|
||||
target["num_ctas"] = num_ctas
|
||||
_device_backend.add_stages(target, extern_libs, stages)
|
||||
|
||||
# find out the signature of the function
|
||||
if isinstance(fn, JITFunction):
|
||||
@@ -482,6 +627,7 @@ def compile(fn, **kwargs):
|
||||
src = Path(fn).read_text()
|
||||
import re
|
||||
match = re.search(prototype_pattern[ir_name], src, re.MULTILINE)
|
||||
# TODO: support function attributes at group 3 (e.g., device function)
|
||||
name, signature = match.group(1), match.group(2)
|
||||
types = re.findall(arg_type_pattern[ir_name], signature)
|
||||
if ir_name == 'ttgir':
|
||||
@@ -494,7 +640,12 @@ def compile(fn, **kwargs):
|
||||
first_stage = list(stages.keys()).index(ir_name)
|
||||
|
||||
# create cache manager
|
||||
fn_cache_manager = get_cache_manager(make_hash(fn, arch, get_env_vars(), **kwargs))
|
||||
fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), **kwargs))
|
||||
# managers used to dump and override IR for debugging
|
||||
enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
|
||||
fn_override_manager = get_override_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True))
|
||||
fn_dump_manager = get_dump_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True))
|
||||
|
||||
# determine name and extension type of provided function
|
||||
if isinstance(fn, JITFunction):
|
||||
name, ext = fn.__name__, "ast"
|
||||
@@ -529,7 +680,7 @@ def compile(fn, **kwargs):
|
||||
"enable_persistent": enable_persistent,
|
||||
"constants": _get_jsonable_constants(constants),
|
||||
"debug": debug,
|
||||
"arch": arch, }
|
||||
"target": target, }
|
||||
metadata.update(get_env_vars())
|
||||
if ext == "ptx":
|
||||
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
|
||||
@@ -539,7 +690,7 @@ def compile(fn, **kwargs):
|
||||
metadata["device_type"] = device_type
|
||||
|
||||
first_stage = list(stages.keys()).index(ext)
|
||||
asm = dict()
|
||||
asm = LazyDict()
|
||||
module = fn
|
||||
# run compilation pipeline and populate metadata
|
||||
for ir_name, (parse, compile_kernel) in list(stages.items())[first_stage:]:
|
||||
@@ -557,7 +708,11 @@ def compile(fn, **kwargs):
|
||||
metadata_group[extra_file_name] = fn_cache_manager.put(next_module[1], extra_file_name)
|
||||
else:
|
||||
metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
|
||||
fn_cache_manager.put(next_module, ir_filename)
|
||||
fn_dump_manager.put(next_module, ir_filename)
|
||||
if (enable_override and fn_override_manager.has_file(ir_filename)):
|
||||
print(f"\nOverriding kernel with file {ir_filename}")
|
||||
full_name = fn_override_manager.get_file(ir_filename)
|
||||
next_module = parse(full_name)
|
||||
else:
|
||||
if ir_name == "amdgcn":
|
||||
extra_file_name = f"{name}.hsaco_path"
|
||||
@@ -569,6 +724,7 @@ def compile(fn, **kwargs):
|
||||
|
||||
if ir_name == "cubin":
|
||||
asm[ir_name] = next_module
|
||||
asm["sass"] = lambda: get_sass(next_module)
|
||||
elif ir_name == "amdgcn":
|
||||
asm[ir_name] = str(next_module[0])
|
||||
else:
|
||||
@@ -579,11 +735,19 @@ def compile(fn, **kwargs):
|
||||
else:
|
||||
metadata["shared"] = get_shared_memory_size(module)
|
||||
if ir_name == "ttgir":
|
||||
<<<<<<< HEAD
|
||||
metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module)
|
||||
if metadata["enable_warp_specialization"]:
|
||||
if is_hip():
|
||||
metadata["num_warps"] = _device_backend.get_num_warps(next_module)
|
||||
else:
|
||||
=======
|
||||
if is_hip():
|
||||
metadata["num_warps"] = _device_backend.get_num_warps(next_module)
|
||||
else:
|
||||
metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module)
|
||||
if metadata["enable_warp_specialization"]:
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
metadata["num_warps"] = get_num_warps(next_module)
|
||||
if ir_name == "ptx":
|
||||
metadata["name"] = get_kernel_name(next_module, pattern='// .globl')
|
||||
@@ -723,16 +887,3 @@ class CompiledKernel:
|
||||
self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function,
|
||||
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand)
|
||||
return runner
|
||||
|
||||
def get_sass(self, fun=None):
|
||||
if 'sass' in self.asm:
|
||||
return self.asm['sass']
|
||||
fd, path = tempfile.mkstemp()
|
||||
try:
|
||||
with open(fd, 'wb') as cubin:
|
||||
cubin.write(self.asm['cubin'])
|
||||
self.sass = extract(path, fun)
|
||||
finally:
|
||||
os.remove(path)
|
||||
self.asm['sass'] = self.sass
|
||||
return self.sass
|
||||
|
||||
@@ -63,8 +63,9 @@ def ty_to_cpp(ty):
|
||||
|
||||
|
||||
def generate_launcher(constants, signature, ids):
|
||||
start_desc = len(signature)
|
||||
signature = generate_cu_signature(constants, signature, ids)
|
||||
# Record the end of regular arguments;
|
||||
# subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
|
||||
signature, desc_start_idx = generate_cu_signature(constants, signature, ids)
|
||||
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
|
||||
|
||||
def _extracted_type(ty):
|
||||
@@ -99,7 +100,11 @@ def generate_launcher(constants, signature, ids):
|
||||
|
||||
# generate glue code
|
||||
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
|
||||
<<<<<<< HEAD
|
||||
params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)]
|
||||
=======
|
||||
params = [i for i in signature.keys() if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)]
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
src = f"""
|
||||
#include \"cuda.h\"
|
||||
#include <stdbool.h>
|
||||
@@ -116,7 +121,10 @@ static inline void gpuAssert(CUresult code, const char *file, int line)
|
||||
char err[1024] = {{0}};
|
||||
strcat(err, prefix);
|
||||
strcat(err, str);
|
||||
PyGILState_STATE gil_state;
|
||||
gil_state = PyGILState_Ensure();
|
||||
PyErr_SetString(PyExc_RuntimeError, err);
|
||||
PyGILState_Release(gil_state);
|
||||
}}
|
||||
}}
|
||||
|
||||
@@ -251,6 +259,9 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
_launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''});
|
||||
Py_END_ALLOW_THREADS;
|
||||
if (PyErr_Occurred()) {{
|
||||
return NULL;
|
||||
}}
|
||||
|
||||
if (launch_exit_hook != Py_None && !PyObject_CallObject(launch_exit_hook, args)) {{
|
||||
return NULL;
|
||||
|
||||
@@ -26,12 +26,11 @@ from ..runtime import driver
|
||||
|
||||
def generate_cu_signature(constants, signature, ids):
|
||||
# CUtensorMap*s are always the last arguments
|
||||
num_regular_signatures = max(signature.keys()) + 1 if len(signature) > 0 else 0
|
||||
if ids["ids_of_tensormaps"] is not None:
|
||||
signature = signature.copy()
|
||||
num_signature = len(signature)
|
||||
for i, _ in enumerate(ids["ids_of_tensormaps"]):
|
||||
signature[num_signature + i] = '*CUtensorMap'
|
||||
return signature
|
||||
signature[num_regular_signatures + i] = '*CUtensorMap'
|
||||
return signature, num_regular_signatures
|
||||
|
||||
|
||||
def dummy_tensormaps_info(n=2):
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
from typing import Tuple
|
||||
|
||||
import dataclasses
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ExecutionContext:
|
||||
program_id: Tuple[int]
|
||||
program_size: Tuple[int]
|
||||
@@ -1,171 +0,0 @@
|
||||
import itertools
|
||||
import random
|
||||
from typing import Tuple
|
||||
|
||||
from .. import language as tl
|
||||
# import .language.core as lcore
|
||||
from ..language import core as lcore
|
||||
from . import torch_wrapper
|
||||
from .core import ExecutionContext
|
||||
from .memory_map import MemoryMap
|
||||
from .tl_lang import (TritonLangProxy, WrappedTensor, _primitive_to_tensor,
|
||||
debugger_constexpr)
|
||||
|
||||
torch = torch_wrapper.torch
|
||||
tl_method_backup = {}
|
||||
|
||||
|
||||
def get_proxy_method(proxy, name):
|
||||
method = getattr(proxy, name)
|
||||
|
||||
def fun(*args, **kwarg):
|
||||
return method(*args, **kwarg)
|
||||
|
||||
return fun
|
||||
|
||||
|
||||
def attach_triton(module, proxy):
|
||||
method_list = [func for func in dir(TritonLangProxy) if func[0] != "_"]
|
||||
for name in method_list:
|
||||
if hasattr(module, name):
|
||||
attr = getattr(module, name)
|
||||
tl_method_backup[name] = attr
|
||||
if callable(attr):
|
||||
setattr(module, name, get_proxy_method(proxy, name))
|
||||
else:
|
||||
setattr(module, name, getattr(proxy, name))
|
||||
|
||||
|
||||
def detach_triton(module):
|
||||
for name, method in tl_method_backup.items():
|
||||
setattr(module, name, method)
|
||||
|
||||
|
||||
def program_ids_from_grid(grid: Tuple[int, ...]) -> Tuple[int, ...]:
|
||||
# reverse the grid dimensions and generate the range for each dimension
|
||||
reversed_grid = reversed(grid)
|
||||
ranges_for_each_dimension = [range(dim) for dim in reversed_grid]
|
||||
|
||||
# gen all combinations
|
||||
index_combinations = list(itertools.product(*ranges_for_each_dimension))
|
||||
random.shuffle(index_combinations)
|
||||
|
||||
for index_combination in index_combinations:
|
||||
yield index_combination
|
||||
|
||||
|
||||
class DebuggerFunction:
|
||||
def __init__(self, func, grid=(1,)):
|
||||
self.func = func
|
||||
self.grid = grid
|
||||
|
||||
def _is_constexpr(self, name):
|
||||
return name in self.func.__annotations__ and self.func.__annotations__[name] is lcore.constexpr
|
||||
|
||||
def _get_constexpr(self):
|
||||
result = []
|
||||
for name, annotation in self.func.__annotations__.items():
|
||||
if annotation is lcore.constexpr:
|
||||
result.append(name)
|
||||
return result
|
||||
|
||||
def _assert_constexpr(self, **kwargs):
|
||||
constexp = self._get_constexpr()
|
||||
missing = [i for i in constexp if i not in kwargs.keys()]
|
||||
assert len(missing) == 0, f"You must specify constexpr {missing}"
|
||||
|
||||
def _get_grid(self, **kwargs):
|
||||
if callable(self.grid):
|
||||
return self.grid(kwargs)
|
||||
else:
|
||||
return self.grid
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
self._assert_constexpr(**kwargs)
|
||||
|
||||
memory = MemoryMap()
|
||||
|
||||
def convert_arg(v):
|
||||
name, arg = v
|
||||
if torch.is_tensor(arg):
|
||||
ptr = memory.add_tensor(arg)
|
||||
return WrappedTensor(torch.tensor([ptr], dtype=torch.int64, device="cuda"))
|
||||
if self._is_constexpr(name):
|
||||
return debugger_constexpr(arg)
|
||||
return WrappedTensor(_primitive_to_tensor(arg))
|
||||
|
||||
new_args = tuple(map(convert_arg, zip(self.func.__code__.co_varnames, args)))
|
||||
new_kwargs = {k: convert_arg((k, v)) for (k, v) in kwargs.items() if k not in ["num_warps", "num_stages"]}
|
||||
|
||||
grid = self._get_grid(**kwargs)
|
||||
for program_id in program_ids_from_grid(grid):
|
||||
proxy = TritonLangProxy(memory, ExecutionContext(program_id, grid))
|
||||
attach_triton(tl, proxy)
|
||||
self.func(*new_args, **new_kwargs)
|
||||
detach_triton(tl)
|
||||
|
||||
|
||||
class GridSelector:
|
||||
"""
|
||||
Entry point of the debugger
|
||||
"""
|
||||
|
||||
def __init__(self, func):
|
||||
version = torch.__version__
|
||||
assert version[0] == "2", f"Triton Debugger only supports torch >= 2.0, using {version}"
|
||||
self.func = func
|
||||
|
||||
def __getitem__(self, grid):
|
||||
return DebuggerFunction(self.func, grid)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return DebuggerFunction(self.func)(*args, **kwargs)
|
||||
|
||||
|
||||
class AutotuneGridSelector:
|
||||
def __init__(self, func, autotune_params):
|
||||
self.func = func
|
||||
self.autotune_params = autotune_params
|
||||
|
||||
def __getitem__(self, grid):
|
||||
return AutotuneRunner(self.func, self.autotune_params, grid)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return AutotuneRunner(self.func, self.autotune_params)(*args, **kwargs)
|
||||
|
||||
|
||||
class AutotuneRunner:
|
||||
def __init__(self, func, autotune_params, grid=None):
|
||||
self.func = func
|
||||
self.autotune_params = autotune_params
|
||||
self.grid = grid
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
assert len(self.autotune_params["configs"]) >= 1
|
||||
|
||||
for config in self.autotune_params["configs"][1:]:
|
||||
|
||||
def convert_arg(v):
|
||||
if torch.is_tensor(v):
|
||||
return torch.clone(v)
|
||||
return v
|
||||
|
||||
new_args = tuple(map(convert_arg, args))
|
||||
new_kwargs = {k: convert_arg(v) for k, v in kwargs.items()}
|
||||
if self.grid:
|
||||
self.func[self.grid](*new_args, **new_kwargs, **config.kwargs)
|
||||
else:
|
||||
self.func(*new_args, **new_kwargs, **config.kwargs)
|
||||
|
||||
main_config = self.autotune_params["configs"][0]
|
||||
if self.grid:
|
||||
self.func[self.grid](*args, **kwargs, **main_config.kwargs)
|
||||
else:
|
||||
self.func(*args, **kwargs, **main_config.kwargs)
|
||||
|
||||
|
||||
def triton_debug_autotune(**kwars):
|
||||
def wrapper(func):
|
||||
return AutotuneGridSelector(func, kwars)
|
||||
|
||||
return wrapper
|
||||
@@ -1,102 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
|
||||
from . import torch_wrapper
|
||||
|
||||
torch = torch_wrapper.torch
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RegisteredStorage:
|
||||
storage: torch.Storage
|
||||
dtype: torch.dtype
|
||||
size: int
|
||||
ptr: int
|
||||
|
||||
@property
|
||||
def end_ptr(self) -> int:
|
||||
return self.ptr + self.size
|
||||
|
||||
@property
|
||||
def access_tensor(self) -> torch.Tensor:
|
||||
return torch.tensor(self.storage, dtype=self.dtype, device=self.storage.device)
|
||||
|
||||
def ensure_immutable(self):
|
||||
assert self.storage.data_ptr() == self.ptr and self.storage.size() == self.size
|
||||
|
||||
|
||||
class MemoryMap:
|
||||
storages: [RegisteredStorage]
|
||||
|
||||
def __init__(self):
|
||||
self.storages = []
|
||||
|
||||
def _get_registered_storage(self, pointer: torch.Tensor):
|
||||
max_pointer = torch.max(pointer).item()
|
||||
min_pointer = torch.min(pointer).item()
|
||||
|
||||
registered_storage = next(
|
||||
filter(
|
||||
lambda registered: min_pointer >= registered.ptr and max_pointer < registered.end_ptr, self.storages
|
||||
),
|
||||
None,
|
||||
)
|
||||
if registered_storage is None:
|
||||
raise Exception("Storage not found or pointers spanning multiple tensors")
|
||||
registered_storage.ensure_immutable()
|
||||
return registered_storage
|
||||
|
||||
def add_tensor(self, t: torch.Tensor):
|
||||
storage = t.untyped_storage()
|
||||
self.storages.append(RegisteredStorage(storage, t.dtype, storage.size(), storage.data_ptr()))
|
||||
return t.data_ptr()
|
||||
|
||||
def load(
|
||||
self,
|
||||
pointer: torch.Tensor,
|
||||
mask: torch.Tensor = None,
|
||||
other=0.0,
|
||||
):
|
||||
assert pointer.is_cuda
|
||||
assert 0 < pointer.dim() < 3
|
||||
assert pointer.dtype == torch.int64
|
||||
|
||||
if mask is None:
|
||||
mask = torch.ones_like(pointer).bool()
|
||||
assert mask.is_cuda
|
||||
assert 0 < mask.dim() < 3
|
||||
assert mask.dtype == torch.bool
|
||||
mask = mask.expand(pointer.size())
|
||||
|
||||
if torch.all(~mask):
|
||||
# Todo: The type is wrong here, we can't determine the correct type
|
||||
return torch.full_like(pointer, fill_value=other, dtype=torch.float16, device="cuda")
|
||||
|
||||
registered_storage = self._get_registered_storage(pointer[mask])
|
||||
access_tensor = registered_storage.access_tensor
|
||||
|
||||
index_tensor = pointer - registered_storage.ptr
|
||||
|
||||
block = torch.full_like(pointer, fill_value=other, dtype=access_tensor.dtype, device="cuda")
|
||||
block[mask] = access_tensor[index_tensor[mask]]
|
||||
return block
|
||||
|
||||
def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None):
|
||||
assert 0 < pointer.dim() < 3
|
||||
assert pointer.dtype == torch.int64
|
||||
|
||||
if mask is None:
|
||||
mask = torch.ones_like(pointer).bool()
|
||||
assert 0 < mask.dim() < 3
|
||||
assert mask.dtype == torch.bool
|
||||
mask = mask.expand(pointer.size())
|
||||
|
||||
if torch.all(~mask):
|
||||
return
|
||||
|
||||
registered_storage = self._get_registered_storage(pointer[mask])
|
||||
access_tensor = registered_storage.access_tensor
|
||||
|
||||
index_tensor = pointer - registered_storage.ptr
|
||||
access_tensor[index_tensor[mask]] = value[mask].to(access_tensor.dtype)
|
||||
@@ -1,641 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ..language import core as lcore
|
||||
from . import torch_wrapper
|
||||
from .core import ExecutionContext
|
||||
from .memory_map import MemoryMap
|
||||
|
||||
torch = torch_wrapper.torch
|
||||
|
||||
|
||||
def _primitive_to_tensor(x):
|
||||
"""
|
||||
Converts various Python primitive data types to PyTorch tensor.
|
||||
"""
|
||||
tensor_args = {"device": "cuda"}
|
||||
if isinstance(x, bool):
|
||||
return torch.tensor([x], dtype=torch.bool, **tensor_args)
|
||||
elif isinstance(x, int):
|
||||
if -(2**31) <= x < 2**31:
|
||||
return torch.tensor([x], dtype=torch.int32, **tensor_args)
|
||||
elif -(2**63) <= x < 2**63:
|
||||
return torch.tensor([x], dtype=torch.int64, **tensor_args)
|
||||
else:
|
||||
raise RuntimeError(f"Nonrepresentable integer {x}.")
|
||||
elif isinstance(x, float):
|
||||
return torch.tensor([x], dtype=torch.float32, **tensor_args)
|
||||
elif torch.is_tensor(x):
|
||||
return x
|
||||
elif isinstance(x, WrappedTensor):
|
||||
return x
|
||||
elif isinstance(x, debugger_constexpr):
|
||||
if x.value is None:
|
||||
return None
|
||||
return _primitive_to_tensor(x.value)
|
||||
elif x is None:
|
||||
return None
|
||||
assert False, f"cannot convert {x} of type {type(x)} to tensor"
|
||||
|
||||
|
||||
def _infer_tensor(func):
|
||||
"""
|
||||
A decorator function to harmonize function args:
|
||||
- converts primitives to PyTorch tensors
|
||||
- wraps PyTorch tensors with WrappedTensors
|
||||
"""
|
||||
def wrapper(*args):
|
||||
new_args = tuple(map(lambda v: _primitive_to_tensor(v), args))
|
||||
new_args = tuple(map(lambda v: WrappedTensor(v) if torch.is_tensor(v) else v, new_args))
|
||||
|
||||
return func(*new_args)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _tensor_operation(func):
|
||||
"""
|
||||
A decorator function to unwrap WrappedTensors and debugger_constexpr before calling the function.
|
||||
Can be combined with _infer_tensor decorator to harmonize args (everything to torch tensor).
|
||||
"""
|
||||
def wrapper(*args, **kwargs):
|
||||
for arg in args:
|
||||
assert not torch.is_tensor(arg), "unexpected tensor argument"
|
||||
|
||||
def unwrap_tensor(v):
|
||||
if isinstance(v, WrappedTensor):
|
||||
return v.tensor
|
||||
if isinstance(v, debugger_constexpr):
|
||||
return v.value
|
||||
return v
|
||||
|
||||
new_args = tuple(map(unwrap_tensor, args))
|
||||
new_kwargs = {k: unwrap_tensor(v) for k, v in kwargs.items()}
|
||||
|
||||
result = func(args[0], *new_args[1:], **new_kwargs)
|
||||
return WrappedTensor(result) if torch.is_tensor(result) else result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class debugger_constexpr:
|
||||
def __init__(self, value):
|
||||
if isinstance(value, debugger_constexpr):
|
||||
self.value = value.value
|
||||
else:
|
||||
self.value = value
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "debugger_constexpr(" + str(self.value) + ")"
|
||||
|
||||
def __index__(self) -> int:
|
||||
return self.value
|
||||
|
||||
def __bool__(self):
|
||||
return bool(self.value)
|
||||
|
||||
def __ge__(self, other):
|
||||
other = other.value if isinstance(other, debugger_constexpr) else other
|
||||
return self.value >= other
|
||||
|
||||
def __gt__(self, other):
|
||||
other = other.value if isinstance(other, debugger_constexpr) else other
|
||||
return self.value > other
|
||||
|
||||
def __le__(self, other):
|
||||
other = other.value if isinstance(other, debugger_constexpr) else other
|
||||
return self.value <= other
|
||||
|
||||
def __lt__(self, other):
|
||||
other = other.value if isinstance(other, debugger_constexpr) else other
|
||||
return self.value < other
|
||||
|
||||
def __eq__(self, other):
|
||||
other = other.value if isinstance(other, debugger_constexpr) else other
|
||||
return self.value == other
|
||||
|
||||
def __or__(self, other):
|
||||
other = other.value if isinstance(other, debugger_constexpr) else other
|
||||
return self.value | other
|
||||
|
||||
def __ror__(self, other):
|
||||
other = other.value if isinstance(other, debugger_constexpr) else other
|
||||
return self.value | other
|
||||
|
||||
def __and__(self, other):
|
||||
other = other.value if isinstance(other, debugger_constexpr) else other
|
||||
return self.value & other
|
||||
|
||||
def __rand__(self, other):
|
||||
other = other.value if isinstance(other, debugger_constexpr) else other
|
||||
return self.value & other
|
||||
|
||||
def to(self, dtype, bitcast=False, _builder=None):
|
||||
if dtype in [torch.int64]:
|
||||
ret_ty = int
|
||||
elif dtype == torch.bool:
|
||||
ret_ty = bool
|
||||
elif dtype in [torch.float64]:
|
||||
ret_ty = float
|
||||
else:
|
||||
raise ValueError("dtype not supported in debugger")
|
||||
return debugger_constexpr(ret_ty(self.value))
|
||||
|
||||
|
||||
class WrappedTensor:
|
||||
def __init__(self, tensor):
|
||||
self.tensor = tensor
|
||||
|
||||
def __index__(self) -> int:
|
||||
return self.tensor.item()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "wrapped_" + str(self.tensor)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return torch.all(self.tensor == True).item() # noqa: E712
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.tensor.dtype
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __add__(self, other):
|
||||
return torch.add(self.tensor, other)
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __radd__(self, other):
|
||||
return self.__add__(other)
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __sub__(self, other):
|
||||
return torch.sub(self.tensor, other)
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __rsub__(self, other):
|
||||
return torch.sub(other, self.tensor)
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __mul__(self, other):
|
||||
return torch.mul(self.tensor, other)
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __rmul__(self, other):
|
||||
return self.__mul__(other)
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __truediv__(self, other):
|
||||
return torch.div(self.tensor, other)
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __rtruediv__(self, other):
|
||||
return torch.div(other, self.tensor)
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __floordiv__(self, other):
|
||||
return torch.floor_divide(self.tensor, other)
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __rfloordiv__(self, other):
|
||||
return torch.floor_divide(other, self.tensor)
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __mod__(self, other):
|
||||
return torch.remainder(self.tensor, other)
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __rmod__(self, other):
|
||||
return torch.remainder(other, self.tensor)
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __neg__(self):
|
||||
return -self.tensor
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __invert__(self):
|
||||
return ~self.tensor
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __and__(self, other):
|
||||
return torch.bitwise_and(self.tensor, other)
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __or__(self, other):
|
||||
return torch.bitwise_or(self.tensor, other)
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __xor__(self, other):
|
||||
return torch.bitwise_xor(self.tensor, other)
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __lshift__(self, other):
|
||||
return torch.bitwise_left_shift(self.tensor, other)
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __rshift__(self, other):
|
||||
return torch.bitwise_right_shift(self.tensor, other)
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __gt__(self, other):
|
||||
return self.tensor > other
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __rgt__(self, other):
|
||||
return other > self.tensor
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __ge__(self, other):
|
||||
return self.tensor >= other
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __rge__(self, other):
|
||||
return other >= self.tensor
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __lt__(self, other):
|
||||
return self.tensor < other
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __rlt__(self, other):
|
||||
return other < self.tensor
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __le__(self, other):
|
||||
return self.tensor <= other
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __rle__(self, other):
|
||||
return other <= self.tensor
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __eq__(self, other):
|
||||
return torch.equal(self.tensor, other)
|
||||
|
||||
@_infer_tensor
|
||||
@_tensor_operation
|
||||
def __ne__(self, other):
|
||||
return not torch.equal(self.tensor, other)
|
||||
|
||||
@_tensor_operation
|
||||
def __getitem__(self, slices):
|
||||
return self.tensor.__getitem__(slices)
|
||||
# if isinstance(slices, slice):
|
||||
# slices = [slices]
|
||||
# src_shape = self.shape
|
||||
# dst_shape = []
|
||||
# curr = 0
|
||||
# for sl in slices:
|
||||
# if isinstance(sl, constexpr) and sl.value is None:
|
||||
# dst_shape.append(1)
|
||||
# elif sl == slice(None, None, None):
|
||||
# dst_shape.append(src_shape[curr].value)
|
||||
# curr += 1
|
||||
# ret = torch.reshape(self.tensor, dst_shape, )
|
||||
# return ret
|
||||
|
||||
@_tensor_operation
|
||||
def to(self, dtype, bitcast=False):
|
||||
return self.tensor.to(dtype)
|
||||
# if isinstance(bitcast, constexpr):
|
||||
# bitcast = bitcast.value
|
||||
# if bitcast:
|
||||
# return semantic.bitcast(self, dtype, )
|
||||
# return semantic.cast(self, dtype, )
|
||||
|
||||
|
||||
def _constexpr_to_value(v):
|
||||
if isinstance(v, debugger_constexpr):
|
||||
return v.value
|
||||
return v
|
||||
|
||||
|
||||
class TritonLangProxy:
|
||||
_memory_map: MemoryMap
|
||||
_context: ExecutionContext
|
||||
|
||||
def __init__(self, memory_map: MemoryMap, context: ExecutionContext):
|
||||
self._memory_map = memory_map
|
||||
self._context = context
|
||||
|
||||
# Types
|
||||
# Removed void, int1, float8, uint16, uint32, uint64, pi32_t
|
||||
|
||||
# constexpr = debugger_constexpr
|
||||
|
||||
# Program functions
|
||||
|
||||
@_tensor_operation
|
||||
def load(
|
||||
self,
|
||||
pointer: torch.Tensor,
|
||||
mask: torch.Tensor = None,
|
||||
other=0.0,
|
||||
cache_modifier="",
|
||||
eviction_policy="",
|
||||
volatile=False,
|
||||
):
|
||||
return self._memory_map.load(pointer, mask, other)
|
||||
|
||||
@_tensor_operation
|
||||
def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None):
|
||||
return self._memory_map.store(pointer, value, mask)
|
||||
|
||||
@_tensor_operation
|
||||
def program_id(self, axis):
|
||||
assert axis < len(self._context.program_id)
|
||||
return torch.tensor([self._context.program_id[axis]], dtype=torch.int32, device="cuda")
|
||||
|
||||
@_tensor_operation
|
||||
def num_programs(self, axis):
|
||||
assert axis < len(self._context.program_size)
|
||||
return torch.tensor([self._context.program_size[axis]], dtype=torch.int32, device="cuda")
|
||||
|
||||
@_tensor_operation
|
||||
def arange(self, start, end):
|
||||
return torch.arange(start=start, end=end, dtype=torch.int32, device="cuda")
|
||||
|
||||
@_tensor_operation
|
||||
def zeros(self, shape, dtype):
|
||||
for i, d in enumerate(shape):
|
||||
if not isinstance(d, debugger_constexpr):
|
||||
raise TypeError(f"Shape element {i} must have type `constexpr`")
|
||||
if not isinstance(d.value, int):
|
||||
raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
|
||||
shape = [x.value for x in shape]
|
||||
if isinstance(dtype, lcore.dtype):
|
||||
if dtype.is_fp32():
|
||||
dtype = torch.float32
|
||||
elif dtype.is_fp16():
|
||||
dtype = torch.float16
|
||||
elif dtype.is_bf16():
|
||||
dtype = torch.bfloat16
|
||||
elif dtype.is_int32():
|
||||
dtype = torch.int32
|
||||
elif dtype.is_int16():
|
||||
dtype = torch.int16
|
||||
elif dtype.is_int8():
|
||||
dtype = torch.int8
|
||||
else:
|
||||
raise TypeError(f"Unsupported dtype {dtype}")
|
||||
return torch.zeros(size=shape, dtype=dtype, device="cuda")
|
||||
|
||||
@_tensor_operation
|
||||
def dequantize(self, input, scale, shift, nbit, dst_ty=None):
|
||||
if dst_ty is None:
|
||||
dst_ty = torch.float16
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def broadcast(self, input, other):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def broadcast_to(self, input, shape):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def cat(self, input, shape):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def reshape(self, input, shape):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def dot(self, input, other, trans_a=False, trans_b=False, allow_tf32=True):
|
||||
assert input.dtype == other.dtype
|
||||
if trans_a:
|
||||
input = input.T
|
||||
if trans_b:
|
||||
other = other.T
|
||||
return torch.matmul(input=input, other=other)
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_cas(self, pointer, cmp, val):
|
||||
stored = self._memory_map.load(pointer, None, 0.0)
|
||||
if not isinstance(cmp, torch.Tensor):
|
||||
cmp = torch.tensor([cmp], dtype=stored.dtype, device="cuda")
|
||||
if not isinstance(val, torch.Tensor):
|
||||
val = torch.tensor([val], dtype=stored.dtype, device="cuda")
|
||||
if stored == cmp:
|
||||
self._memory_map.store(pointer, val, None)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_xchg(self, pointer, val, mask=None):
|
||||
if isinstance(val, int):
|
||||
val = torch.tensor([val], dtype=torch.int32, device="cuda")
|
||||
stored = self._memory_map.load(pointer, mask, 0.0)
|
||||
self._memory_map.store(pointer, val, mask)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_add(self, pointer, val, mask=None):
|
||||
# arbitrary other value as it will masked during storing
|
||||
stored = self._memory_map.load(pointer, mask, 0.0)
|
||||
result = stored + val
|
||||
self._memory_map.store(pointer, result, mask)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_max(self, pointer, val, mask=None):
|
||||
stored = self._memory_map.load(pointer, mask, 0.0)
|
||||
result = torch.maximum(stored, val)
|
||||
self._memory_map.store(pointer, result, mask)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_min(self, pointer, val, mask=None):
|
||||
stored = self._memory_map.load(pointer, mask, 0.0)
|
||||
result = torch.minimum(stored, val)
|
||||
self._memory_map.store(pointer, result, mask)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_and(self, pointer, val, mask=None):
|
||||
stored = self._memory_map.load(pointer, mask, 0)
|
||||
result = torch.bitwise_and(stored, val)
|
||||
self._memory_map.store(pointer, result, mask)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_or(self, pointer, val, mask=None):
|
||||
stored = self._memory_map.load(pointer, mask, 0)
|
||||
result = torch.bitwise_or(stored, val)
|
||||
self._memory_map.store(pointer, result, mask)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_xor(self, pointer, val, mask=None):
|
||||
stored = self._memory_map.load(pointer, mask, 0)
|
||||
result = torch.bitwise_xor(stored, val)
|
||||
self._memory_map.store(pointer, result, mask)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def where(self, condition, x, y):
|
||||
condition = _primitive_to_tensor(condition)
|
||||
x = _primitive_to_tensor(x)
|
||||
y = _primitive_to_tensor(y)
|
||||
return torch.where(condition, x, y)
|
||||
|
||||
@_tensor_operation
|
||||
def umulhi(self, x, y):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def fdiv(self, x, y, ieee_rounding=False):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def exp(self, x):
|
||||
return torch.exp(x)
|
||||
|
||||
@_tensor_operation
|
||||
def log(self, x):
|
||||
return torch.log(x)
|
||||
|
||||
@_tensor_operation
|
||||
def cos(self, x):
|
||||
return torch.cos(x)
|
||||
|
||||
@_tensor_operation
|
||||
def sin(self, x):
|
||||
return torch.sin(x)
|
||||
|
||||
@_tensor_operation
|
||||
def sqrt(self, x):
|
||||
return torch.sqrt(x)
|
||||
|
||||
@_tensor_operation
|
||||
def globaltimer(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def clock(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def debug_barrier(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def multiple_of(self, input, values):
|
||||
return input
|
||||
|
||||
@_tensor_operation
|
||||
def max_contiguous(self, input, values):
|
||||
return input
|
||||
|
||||
@_tensor_operation
|
||||
def max_constancy(self, input, values):
|
||||
return input
|
||||
|
||||
@_tensor_operation
|
||||
def abs(self, x):
|
||||
return torch.abs(x)
|
||||
|
||||
@_tensor_operation
|
||||
def cdiv(self, x, div):
|
||||
return (x + div - 1) // div
|
||||
|
||||
@_tensor_operation
|
||||
def minimum(self, x, y):
|
||||
if isinstance(x, int):
|
||||
x = torch.tensor(x, device="cuda")
|
||||
if isinstance(y, int):
|
||||
y = torch.tensor(y, device="cuda")
|
||||
return torch.minimum(x, y)
|
||||
|
||||
@_tensor_operation
|
||||
def maximum(self, x, y):
|
||||
return torch.maximum(x, y)
|
||||
|
||||
@_tensor_operation
|
||||
def sigmoid(self, x):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def softmax(self, x, ieee_rounding=False):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def ravel(self, x):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def swizzle2d(self, i, j, size_i, size_j, size_g):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def zeros_like(self, input):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def max(self, input, axis=None):
|
||||
if axis is None:
|
||||
return torch.max(input)
|
||||
return torch.max(input, dim=axis).values
|
||||
|
||||
@_tensor_operation
|
||||
def argmax(self, input, axis):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def min(self, input, axis=None):
|
||||
if axis is None:
|
||||
return torch.min(input)
|
||||
return torch.min(input, dim=axis).values
|
||||
|
||||
@_tensor_operation
|
||||
def argmin(self, input, axis):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def sum(self, input, axis=None):
|
||||
if axis is None:
|
||||
return torch.sum(input)
|
||||
return torch.sum(input, dim=axis)
|
||||
|
||||
@_tensor_operation
|
||||
def xor_sum(self, input, axis):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def cumsum(self, input, axis=None):
|
||||
if axis is None:
|
||||
return torch.cumsum(input)
|
||||
return torch.cumsum(input, dim=axis)
|
||||
|
||||
@_tensor_operation
|
||||
def cumprod(self, input, axis=None):
|
||||
if axis is None:
|
||||
return torch.cumprod(input)
|
||||
return torch.cumprod(input, dim=axis)
|
||||
@@ -1,18 +0,0 @@
|
||||
try:
|
||||
import torch as _torch
|
||||
except ImportError:
|
||||
_torch = None
|
||||
|
||||
|
||||
class TorchWrapper:
|
||||
"""
|
||||
Helps in making torch an optional dependency
|
||||
"""
|
||||
|
||||
def __getattr__(self, name):
|
||||
if _torch is None:
|
||||
raise ImportError("Triton requires PyTorch to be installed")
|
||||
return getattr(_torch, name)
|
||||
|
||||
|
||||
torch = TorchWrapper()
|
||||
@@ -293,7 +293,7 @@ class dtype:
|
||||
return self.name
|
||||
|
||||
def __repr__(self):
|
||||
return f'triton.language.{self.name}'
|
||||
return f'triton.language.{str(self)}'
|
||||
|
||||
|
||||
class pointer_type(dtype):
|
||||
@@ -551,9 +551,7 @@ class tensor:
|
||||
# IR handle
|
||||
self.handle = handle
|
||||
# Block shape
|
||||
self.shape = (1, )
|
||||
if type.is_block():
|
||||
self.shape = type.shape
|
||||
self.shape = type.shape if type.is_block() else ()
|
||||
self.numel = 1
|
||||
for s in self.shape:
|
||||
self.numel *= s
|
||||
@@ -564,14 +562,15 @@ class tensor:
|
||||
self.shape = [constexpr(s) for s in self.shape]
|
||||
|
||||
def __str__(self) -> str:
|
||||
# ex. "float32[3,4]"
|
||||
return str(self.dtype) + '[' + ','.join(str(s) for s in self.shape) + ']'
|
||||
# ex. "float32[16, 32]"
|
||||
return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']'
|
||||
|
||||
@builtin
|
||||
def __add__(self, other, _builder=None):
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.add(self, other, _builder)
|
||||
|
||||
@builtin
|
||||
def __radd__(self, other, _builder=None):
|
||||
return self.__add__(other, _builder=_builder)
|
||||
|
||||
@@ -580,6 +579,7 @@ class tensor:
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.sub(self, other, _builder)
|
||||
|
||||
@builtin
|
||||
def __rsub__(self, other, _builder=None):
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.sub(other, self, _builder)
|
||||
@@ -589,6 +589,7 @@ class tensor:
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.mul(self, other, _builder)
|
||||
|
||||
@builtin
|
||||
def __rmul__(self, other, _builder=None):
|
||||
return self.__mul__(other, _builder=_builder)
|
||||
|
||||
@@ -597,6 +598,7 @@ class tensor:
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.truediv(self, other, _builder)
|
||||
|
||||
@builtin
|
||||
def __rtruediv__(self, other, _builder=None):
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.truediv(other, self, _builder)
|
||||
@@ -688,8 +690,6 @@ class tensor:
|
||||
else:
|
||||
return semantic.lshr(other, self, _builder)
|
||||
|
||||
# comparison operators
|
||||
|
||||
# >
|
||||
@builtin
|
||||
def __gt__(self, other, _builder=None):
|
||||
@@ -763,11 +763,11 @@ class tensor:
|
||||
|
||||
@builtin
|
||||
def __getitem__(self, slices, _builder=None):
|
||||
if isinstance(slices, slice):
|
||||
if isinstance(slices, (slice, constexpr)):
|
||||
slices = [slices]
|
||||
ret = self
|
||||
for dim, sl in enumerate(slices):
|
||||
if isinstance(sl, constexpr) and sl.value is None:
|
||||
if sl is None or isinstance(sl, constexpr) and sl.value is None:
|
||||
ret = semantic.expand_dims(ret, dim, _builder)
|
||||
elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None:
|
||||
pass
|
||||
@@ -852,6 +852,8 @@ def arange(start, end, _builder=None):
|
||||
def _shape_check_impl(shape):
|
||||
shape = _constexpr_to_value(shape)
|
||||
for i, d in enumerate(shape):
|
||||
if isinstance(d, int):
|
||||
d = constexpr(d)
|
||||
if not isinstance(d, constexpr):
|
||||
raise TypeError(f"Shape element {i} must have type `constexpr`")
|
||||
if not isinstance(d.value, int):
|
||||
@@ -930,6 +932,12 @@ def broadcast_to(input, shape, _builder=None):
|
||||
|
||||
@builtin
|
||||
def trans(input, _builder=None):
|
||||
"""
|
||||
Returns a transposed tensor.
|
||||
|
||||
:param input: The input tensor.
|
||||
:type input:
|
||||
"""
|
||||
return semantic.trans(input, _builder)
|
||||
|
||||
|
||||
@@ -968,6 +976,15 @@ def view(input, shape, _builder=None):
|
||||
|
||||
@builtin
|
||||
def reshape(input, shape, _builder=None):
|
||||
"""
|
||||
Returns a tensor with the same number of elements as input but with the
|
||||
provided shape.
|
||||
|
||||
:param input: The input tensor.
|
||||
:type input:
|
||||
:param shape: The new shape.
|
||||
:type shape: Tuple[int]
|
||||
"""
|
||||
shape = _shape_check_impl(shape)
|
||||
return semantic.reshape(input, shape, _builder)
|
||||
|
||||
@@ -1012,7 +1029,7 @@ def expand_dims(input, axis, _builder=None):
|
||||
|
||||
|
||||
@builtin
|
||||
def dot(input, other, allow_tf32=True, out_dtype=float32, _builder=None):
|
||||
def dot(input, other, acc=None, allow_tf32=True, max_num_imprecise_acc=None, out_dtype=float32, _builder=None):
|
||||
"""
|
||||
Returns the matrix product of two blocks.
|
||||
|
||||
@@ -1025,7 +1042,8 @@ def dot(input, other, allow_tf32=True, out_dtype=float32, _builder=None):
|
||||
"""
|
||||
allow_tf32 = _constexpr_to_value(allow_tf32)
|
||||
out_dtype = _constexpr_to_value(out_dtype)
|
||||
return semantic.dot(input, other, allow_tf32, out_dtype, _builder)
|
||||
max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc)
|
||||
return semantic.dot(input, other, acc, allow_tf32, max_num_imprecise_acc, out_dtype, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
@@ -1266,6 +1284,14 @@ def where(condition, x, y, _builder=None):
|
||||
|
||||
@builtin
|
||||
def umulhi(x, y, _builder=None):
|
||||
"""
|
||||
Returns the most significant 32 bits of the product of x and y.
|
||||
|
||||
:param x: the input tensor
|
||||
:type x: int32
|
||||
:param y: the input tensor
|
||||
:type y: int32
|
||||
"""
|
||||
x = _to_tensor(x, _builder)
|
||||
y = _to_tensor(y, _builder)
|
||||
return semantic.umulhi(x, y, _builder)
|
||||
@@ -1273,6 +1299,15 @@ def umulhi(x, y, _builder=None):
|
||||
|
||||
@builtin
|
||||
def fdiv(x, y, ieee_rounding=False, _builder=None):
|
||||
"""
|
||||
Returns a floating-point resultant tensor of dividing x by y.
|
||||
|
||||
:param x: the input numerator value.
|
||||
:param y: the input denominator value.
|
||||
:param ieee_rounding: To follow IEEE-754 floating point number
|
||||
rounding mechanism
|
||||
:type ieee_rounding: bool
|
||||
"""
|
||||
ieee_rounding = _constexpr_to_value(ieee_rounding)
|
||||
return semantic.fdiv(x, y, ieee_rounding, _builder)
|
||||
|
||||
|
||||
@@ -13,3 +13,8 @@ def smid(_builder=None):
|
||||
return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [],
|
||||
dtype=core.int32, is_pure=True,
|
||||
pack=1, _builder=_builder)
|
||||
|
||||
|
||||
@core.builtin
|
||||
def num_threads(_builder=None):
|
||||
return core.constexpr(_builder.target.num_warps * 32)
|
||||
|
||||
@@ -12,6 +12,13 @@ import re
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
# TODO: redundant code -- remove after 3P backend refactor
|
||||
|
||||
|
||||
def _is_cuda(target):
|
||||
from ..compiler.compiler import CudaTargetDescriptor
|
||||
return isinstance(target, CudaTargetDescriptor)
|
||||
|
||||
# Create custom exception that prints message "hello"
|
||||
|
||||
|
||||
@@ -28,10 +35,14 @@ class IncompatibleTypeErrorImpl(Exception):
|
||||
# ===----------------------------------------------------------------------===##
|
||||
|
||||
def program_id(axis: int, builder: ir.builder) -> tl.tensor:
|
||||
if axis not in (0, 1, 2):
|
||||
raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}")
|
||||
return tl.tensor(builder.create_get_program_id(axis), tl.int32)
|
||||
|
||||
|
||||
def num_programs(axis: int, builder: ir.builder) -> tl.tensor:
|
||||
if axis not in (0, 1, 2):
|
||||
raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}")
|
||||
return tl.tensor(builder.create_get_num_programs(axis), tl.int32)
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
@@ -131,6 +142,8 @@ def add(input: tl.tensor,
|
||||
input, other = binary_op_type_checking_impl(input, other, builder, True, True)
|
||||
input_scalar_ty = input.type.scalar
|
||||
other_scalar_ty = other.type.scalar
|
||||
if input_scalar_ty.is_ptr() and other_scalar_ty.is_ptr():
|
||||
raise ValueError("cannot add pointers together")
|
||||
|
||||
# offset + ptr
|
||||
# ptr + offset
|
||||
@@ -504,19 +517,18 @@ def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.te
|
||||
if isinstance(value, tl.tensor):
|
||||
assert value.numel.value == 1, "only accepts size-1 tensor"
|
||||
value = cast(value, dtype, builder)
|
||||
ret_ty = tl.block_type(value.dtype, shape)
|
||||
return tl.tensor(builder.create_splat(value.handle, shape), ret_ty)
|
||||
else:
|
||||
# scalar
|
||||
if dtype is None:
|
||||
raise ValueError("dtype must be specified when value is not a tensor")
|
||||
if value == 0:
|
||||
value = builder.get_null_value(dtype.to_ir(builder))
|
||||
else:
|
||||
get_value_fn = getattr(builder, f"get_{dtype.name}")
|
||||
value = get_value_fn(value)
|
||||
if dtype is None:
|
||||
raise ValueError("dtype must be specified when value is not a tensor")
|
||||
ret_ty = tl.block_type(dtype, shape)
|
||||
return tl.tensor(builder.create_splat(value, shape), ret_ty)
|
||||
value = tl.tensor(value, dtype)
|
||||
|
||||
return splat(value, shape, builder)
|
||||
|
||||
|
||||
|
||||
@@ -529,6 +541,13 @@ def ones(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
|
||||
# Shape Manipulation
|
||||
# ===----------------------------------------------------------------------===//
|
||||
|
||||
def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor:
|
||||
assert not value.type.is_block(), "Cannot splat a block tensor"
|
||||
if len(shape) == 0:
|
||||
return value
|
||||
ret_ty = tl.block_type(value.dtype, shape)
|
||||
return tl.tensor(builder.create_splat(value.handle, shape), ret_ty)
|
||||
|
||||
|
||||
def view(input: tl.tensor,
|
||||
dst_shape: List[int],
|
||||
@@ -553,8 +572,12 @@ def reshape(input: tl.tensor,
|
||||
|
||||
|
||||
def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
|
||||
dst_shape = list(input.type.shape)
|
||||
dst_shape = [tl._constexpr_to_value(x) for x in input.shape]
|
||||
dst_shape.insert(axis, 1)
|
||||
|
||||
if not input.type.is_block():
|
||||
return splat(input, shape=dst_shape, builder=builder)
|
||||
|
||||
ret_ty = tl.block_type(input.type.scalar, dst_shape)
|
||||
return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty)
|
||||
|
||||
@@ -674,11 +697,6 @@ def bitcast(input: tl.tensor,
|
||||
dst_ty)
|
||||
|
||||
|
||||
# TODO: architecture descriptor class
|
||||
def _is_cuda(arch):
|
||||
return isinstance(arch, int)
|
||||
|
||||
|
||||
def cast(input: tl.tensor,
|
||||
dst_ty: tl.dtype,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
@@ -693,7 +711,7 @@ def cast(input: tl.tensor,
|
||||
src_sca_ty = src_ty.scalar
|
||||
dst_sca_ty = dst_ty.scalar
|
||||
|
||||
if _is_cuda(builder.arch) and builder.arch < 89 and \
|
||||
if _is_cuda(builder.target) and builder.target.capability < 89 and \
|
||||
(src_sca_ty.is_fp8e4nv() or dst_sca_ty.is_fp8e4nv()):
|
||||
assert False, "fp8e4nv data type is not supported on CUDA arch < 89"
|
||||
|
||||
@@ -1139,13 +1157,20 @@ def atomic_max(ptr: tl.tensor,
|
||||
# for float
|
||||
# return atomic_smax(i_ptr, i_val) if val >= 0
|
||||
# return atomic_umin(i_ptr, i_val) if val < 0
|
||||
i_val = bitcast(val, tl.int32, builder)
|
||||
i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
|
||||
pos = greater_equal(val, tl.tensor(builder.get_fp32(0), sca_ty), builder)
|
||||
neg = less_than(val, tl.tensor(builder.get_fp32(0), sca_ty), builder)
|
||||
if sca_ty not in {tl.float32, tl.float64}:
|
||||
raise TypeError(f"atomic_max not supported for dtype {sca_ty}")
|
||||
|
||||
itype = tl.int32 if sca_ty == tl.float32 else tl.float64
|
||||
zero = full([], 0.0, sca_ty, builder)
|
||||
|
||||
i_val = bitcast(val, itype, builder)
|
||||
i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder)
|
||||
pos = greater_equal(val, zero, builder)
|
||||
neg = less_than(val, zero, builder)
|
||||
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle, sem), i_val.type)
|
||||
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle, sem), i_val.type)
|
||||
return where(pos, pos_ret, neg_ret, builder)
|
||||
ret = where(pos, pos_ret, neg_ret, builder)
|
||||
return bitcast(ret, sca_ty, builder)
|
||||
|
||||
|
||||
def atomic_min(ptr: tl.tensor,
|
||||
@@ -1175,10 +1200,16 @@ def atomic_min(ptr: tl.tensor,
|
||||
# for float
|
||||
# return atomic_smin(i_ptr, i_val) if val >= 0
|
||||
# return atomic_umax(i_ptr, i_val) if val < 0
|
||||
i_val = bitcast(val, tl.int32, builder)
|
||||
i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
|
||||
pos = greater_equal(val, tl.tensor(builder.get_fp32(0), sca_ty), builder)
|
||||
neg = less_than(val, tl.tensor(builder.get_fp32(0), sca_ty), builder)
|
||||
if sca_ty not in {tl.float32, tl.float64}:
|
||||
raise TypeError(f"atomic_min not supported for dtype {sca_ty}")
|
||||
|
||||
itype = tl.int32 if sca_ty == tl.float32 else tl.float64
|
||||
zero = full([], 0.0, sca_ty, builder)
|
||||
|
||||
i_val = bitcast(val, itype, builder)
|
||||
i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder)
|
||||
pos = greater_equal(val, zero, builder)
|
||||
neg = less_than(val, zero, builder)
|
||||
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
|
||||
i_ptr.handle,
|
||||
i_val.handle,
|
||||
@@ -1191,7 +1222,8 @@ def atomic_min(ptr: tl.tensor,
|
||||
and_(mask, neg, builder).handle,
|
||||
sem),
|
||||
i_val.type)
|
||||
return where(pos, pos_ret, neg_ret, builder)
|
||||
ret = where(pos, pos_ret, neg_ret, builder)
|
||||
return bitcast(ret, sca_ty, builder)
|
||||
|
||||
|
||||
def atomic_add(ptr: tl.tensor,
|
||||
@@ -1302,11 +1334,27 @@ def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
|
||||
return False
|
||||
return True
|
||||
|
||||
def gpu_has_mfma() -> bool:
|
||||
if not is_hip():
|
||||
return False
|
||||
return True # mfma supported in ['gfx908', 'gfx90a']
|
||||
|
||||
|
||||
def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
|
||||
if not gpu_has_mfma():
|
||||
return False
|
||||
# TODO: Add check for configurations and types.
|
||||
return True
|
||||
|
||||
|
||||
def dot(lhs: tl.tensor,
|
||||
rhs: tl.tensor,
|
||||
acc: tl.tensor,
|
||||
allow_tf32: bool,
|
||||
max_num_imprecise_acc: int,
|
||||
out_dtype: tl.dtype,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
<<<<<<< HEAD
|
||||
def assert_dtypes_valid(lhs_dtype, rhs_dtype, arch):
|
||||
if is_hip():
|
||||
assert lhs.dtype == rhs.dtype or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp16()) or \
|
||||
@@ -1320,6 +1368,31 @@ def dot(lhs: tl.tensor,
|
||||
# Checks for cuda arch
|
||||
if arch < 90:
|
||||
assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(), "Dot op does not support fp8e4nv on CUDA arch < 90"
|
||||
=======
|
||||
def assert_dtypes_valid(lhs_dtype, rhs_dtype, target):
|
||||
# Checks for non-cuda archs
|
||||
if not _is_cuda(target):
|
||||
assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
|
||||
return
|
||||
# Checks for cuda arch
|
||||
if target.capability < 90:
|
||||
assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(), "Dot op does not support fp8e4nv on CUDA arch < 90"
|
||||
if lhs_dtype.is_fp8() and rhs_dtype.is_fp8():
|
||||
return
|
||||
assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
|
||||
else:
|
||||
assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(), "Dot op does not support fp8e4b15 on CUDA arch >= 90"
|
||||
assert not lhs_dtype.is_fp8e4b15x4() and not rhs_dtype.is_fp8e4b15x4(), "Dot op does not support fp8e4b15x4 on CUDA arch >= 90"
|
||||
if lhs_dtype.is_int() or rhs_dtype.is_int():
|
||||
assert lhs_dtype == rhs_dtype, f"Both operands must be same type. First operand ({lhs_dtype}) and second operand ({rhs_dtype})"
|
||||
assert lhs_dtype.is_int8() or lhs_dtype.is_uint8(), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})"
|
||||
elif lhs_dtype.is_fp8() or rhs_dtype.is_fp8():
|
||||
assert lhs_dtype.is_fp8e4nv() or lhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. First operand ({lhs_dtype})"
|
||||
assert rhs_dtype.is_fp8e4nv() or rhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. Second operand ({rhs_dtype})"
|
||||
else:
|
||||
assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1(), f"Unsupported dtype {lhs_dtype}"
|
||||
assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1(), f"Unsupported dtype {rhs_dtype}"
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
|
||||
else:
|
||||
assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(), "Dot op does not support fp8e4b15 on CUDA arch >= 90"
|
||||
@@ -1339,8 +1412,12 @@ def dot(lhs: tl.tensor,
|
||||
assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
|
||||
return
|
||||
|
||||
<<<<<<< HEAD
|
||||
assert lhs.type.is_block() and rhs.type.is_block()
|
||||
assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.arch)
|
||||
=======
|
||||
assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.target)
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
|
||||
assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!"
|
||||
assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!"
|
||||
@@ -1367,6 +1444,8 @@ def dot(lhs: tl.tensor,
|
||||
assert is_hip() or lhs.shape[1].value >= 32, "small blocks not supported!"
|
||||
_0 = builder.get_int32(0)
|
||||
ret_scalar_ty = tl.int32
|
||||
elif out_dtype.is_bf16():
|
||||
raise ValueError("out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`")
|
||||
elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16():
|
||||
_0 = builder.get_fp32(0)
|
||||
ret_scalar_ty = tl.float32
|
||||
@@ -1401,10 +1480,25 @@ def dot(lhs: tl.tensor,
|
||||
ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
|
||||
ret_ty)
|
||||
return cast(ret, ret_scalar_ty, builder)
|
||||
<<<<<<< HEAD
|
||||
|
||||
_0 = builder.create_splat(_0, [M, N])
|
||||
=======
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
ret_ty = tl.block_type(ret_scalar_ty, [M, N])
|
||||
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
|
||||
if acc is None:
|
||||
acc_handle = builder.create_splat(_0, [M, N])
|
||||
else:
|
||||
acc_handle = acc.handle
|
||||
assert acc.type == ret_ty
|
||||
|
||||
# max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
|
||||
if not (_is_cuda(builder.target) and builder.target.capability == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and ret_scalar_ty.is_fp32()):
|
||||
max_num_imprecise_acc = 0
|
||||
if max_num_imprecise_acc is None:
|
||||
max_num_imprecise_acc = 2**30
|
||||
|
||||
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc),
|
||||
ret_ty)
|
||||
|
||||
|
||||
@@ -1574,7 +1668,7 @@ def abs(x: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
|
||||
|
||||
def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
|
||||
if len(x.shape) != len(values):
|
||||
if max(1, len(x.shape)) != len(values):
|
||||
raise ValueError("Shape of input to multiple_of does not match the length of values")
|
||||
x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context()))
|
||||
return x
|
||||
@@ -1614,6 +1708,8 @@ def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno:
|
||||
|
||||
|
||||
def _convert_elem_to_ir_value(builder, elem, require_i64):
|
||||
if isinstance(elem, int):
|
||||
elem = tl.constexpr(elem)
|
||||
if isinstance(elem, tl.constexpr):
|
||||
return builder.get_int64(elem.value) if require_i64 else builder.get_int32(elem.value)
|
||||
elif isinstance(elem, tl.tensor):
|
||||
|
||||
@@ -160,7 +160,7 @@ def max(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
|
||||
else:
|
||||
return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast)
|
||||
else:
|
||||
if core.constexpr(input.dtype.primitive_bitwidth) < 32:
|
||||
if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32):
|
||||
if core.constexpr(input.dtype.is_floating()):
|
||||
input = input.to(core.float32)
|
||||
else:
|
||||
|
||||
@@ -21,9 +21,10 @@ def _fwd_kernel(
|
||||
Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
stride_oz, stride_oh, stride_om, stride_on,
|
||||
Z, H, N_CTX,
|
||||
Z_H_N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
@@ -31,27 +32,21 @@ def _fwd_kernel(
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
qvk_offset = off_hz * stride_qh
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q + qvk_offset,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
)
|
||||
vk_offset = qvk_offset // stride_qm
|
||||
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K + qvk_offset,
|
||||
shape=(BLOCK_DMODEL, N_CTX),
|
||||
base=K,
|
||||
shape=(BLOCK_DMODEL, Z_H_N_CTX),
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, 0),
|
||||
offsets=(0, vk_offset),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||
order=(0, 1)
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V + qvk_offset,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_vk, stride_vn),
|
||||
offsets=(0, 0),
|
||||
base=V,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(vk_offset, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
)
|
||||
@@ -68,7 +63,11 @@ def _fwd_kernel(
|
||||
# don't work as expected with `exp` in the loop
|
||||
qk_scale = sm_scale * 1.44269504
|
||||
# load q: it will stay in SRAM throughout
|
||||
q = tl.load(Q_block_ptr)
|
||||
|
||||
offs_k = tl.arange(0, BLOCK_DMODEL)
|
||||
Q_ptrs = Q + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
|
||||
q = tl.load(Q_ptrs)
|
||||
|
||||
q = (q * qk_scale).to(K.dtype.element_ty)
|
||||
lo = 0
|
||||
hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
|
||||
@@ -86,8 +85,7 @@ def _fwd_kernel(
|
||||
alpha = tl.math.exp2(m_i - m_i_new)
|
||||
p = tl.math.exp2(qk - m_i_new[:, None])
|
||||
# -- scale and update acc --
|
||||
acc_scale = l_i * 0 + alpha # workaround some compiler bug
|
||||
acc *= acc_scale[:, None]
|
||||
acc *= alpha[:, None]
|
||||
acc += tl.dot(p.to(V.dtype.element_ty), v, allow_tf32=True)
|
||||
# -- update m_i and l_i --
|
||||
l_i = l_i * alpha + tl.sum(p, 1)
|
||||
@@ -101,13 +99,14 @@ def _fwd_kernel(
|
||||
tl.store(l_ptrs, m_i + tl.math.log2(l_i))
|
||||
# write back O
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
base=Out + qvk_offset,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
base=Out,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
offsets=(vk_offset + start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
)
|
||||
# O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
|
||||
tl.store(O_block_ptr, acc.to(K.dtype.element_ty))
|
||||
|
||||
|
||||
@@ -137,13 +136,14 @@ def _bwd_kernel_one_col_block(
|
||||
D,
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
Z, H, N_CTX,
|
||||
off_hz, start_n, num_block,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
SEQUENCE_PARALLEL: tl.constexpr,
|
||||
CAUSAL: tl.constexpr,
|
||||
MMA_V3: tl.constexpr
|
||||
):
|
||||
if SEQUENCE_PARALLEL:
|
||||
DQ += stride_dqa.to(tl.int64) * start_n
|
||||
@@ -159,7 +159,7 @@ def _bwd_kernel_one_col_block(
|
||||
# initialize pointers to value-like data
|
||||
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
|
||||
v_ptrs = V + (offs_n[:, None] * stride_vk + offs_k[None, :] * stride_vn)
|
||||
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk)
|
||||
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
# pointer to row-wise quantities in value-like data
|
||||
@@ -203,8 +203,11 @@ def _bwd_kernel_one_col_block(
|
||||
dq += tl.dot(ds, k, allow_tf32=True)
|
||||
tl.store(dq_ptrs, dq)
|
||||
elif SEQUENCE_PARALLEL:
|
||||
# dq = tl.dot(ds, k, allow_tf32=True)
|
||||
dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds), allow_tf32=True))
|
||||
if MMA_V3:
|
||||
dq = tl.dot(ds, k, allow_tf32=True)
|
||||
else:
|
||||
# not work with mma v3, becuase M % 64 != 0
|
||||
dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds), allow_tf32=True))
|
||||
tl.store(dq_ptrs, dq)
|
||||
|
||||
# increment pointers
|
||||
@@ -212,7 +215,7 @@ def _bwd_kernel_one_col_block(
|
||||
q_ptrs += BLOCK_M * stride_qm
|
||||
do_ptrs += BLOCK_M * stride_qm
|
||||
# write-back
|
||||
dv_ptrs = DV + (offs_n[:, None] * stride_vk + offs_k[None, :] * stride_vn)
|
||||
dv_ptrs = DV + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk)
|
||||
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
|
||||
tl.store(dv_ptrs, dv)
|
||||
tl.store(dk_ptrs, dk)
|
||||
@@ -228,12 +231,13 @@ def _bwd_kernel(
|
||||
D,
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
Z, H, N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
SEQUENCE_PARALLEL: tl.constexpr,
|
||||
CAUSAL: tl.constexpr,
|
||||
MMA_V3: tl.constexpr
|
||||
# fmt: on
|
||||
):
|
||||
qk_scale = sm_scale * 1.44269504
|
||||
@@ -259,13 +263,14 @@ def _bwd_kernel(
|
||||
D,
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
Z, H, N_CTX,
|
||||
off_hz, start_n, num_block_n,
|
||||
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_N=BLOCK_N,
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
|
||||
CAUSAL=CAUSAL,
|
||||
MMA_V3=MMA_V3
|
||||
)
|
||||
else:
|
||||
start_n = tl.program_id(1)
|
||||
@@ -276,13 +281,14 @@ def _bwd_kernel(
|
||||
D,
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
Z, H, N_CTX,
|
||||
off_hz, start_n, num_block_n,
|
||||
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_N=BLOCK_N,
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
|
||||
CAUSAL=CAUSAL,
|
||||
MMA_V3=MMA_V3
|
||||
)
|
||||
|
||||
|
||||
@@ -317,6 +323,7 @@ class _attention(torch.autograd.Function):
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
q.shape[0] * q.shape[1] * q.shape[2],
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
|
||||
IS_CAUSAL=causal,
|
||||
num_warps=num_warps,
|
||||
@@ -332,6 +339,8 @@ class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, do):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
MMA_V3 = capability[0] >= 9
|
||||
BLOCK = 128
|
||||
q, k, v, o, L = ctx.saved_tensors
|
||||
sequence_parallel = ctx.sequence_parallel
|
||||
@@ -365,6 +374,7 @@ class _attention(torch.autograd.Function):
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
|
||||
SEQUENCE_PARALLEL=sequence_parallel,
|
||||
CAUSAL=ctx.causal,
|
||||
MMA_V3=MMA_V3,
|
||||
num_warps=8,
|
||||
num_stages=1,
|
||||
)
|
||||
|
||||
@@ -82,6 +82,7 @@ def _kernel(A, B, C, M, N, K,
|
||||
stride_cm, stride_cn,
|
||||
dot_out_dtype: tl.constexpr,
|
||||
allow_tf32: tl.constexpr,
|
||||
fp8_fast_accum: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr
|
||||
):
|
||||
@@ -118,7 +119,10 @@ def _kernel(A, B, C, M, N, K,
|
||||
if AB_DTYPE:
|
||||
a = a.to(C.dtype.element_ty)
|
||||
b = b.to(C.dtype.element_ty)
|
||||
acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
|
||||
if fp8_fast_accum:
|
||||
acc = tl.dot(a, b, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
|
||||
else:
|
||||
acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
|
||||
A += BLOCK_K * SPLIT_K * stride_ak
|
||||
B += BLOCK_K * SPLIT_K * stride_bk
|
||||
acc = acc.to(C.dtype.element_ty)
|
||||
@@ -140,7 +144,7 @@ class _matmul(torch.autograd.Function):
|
||||
_locks = {}
|
||||
|
||||
@staticmethod
|
||||
def _call(a, b, dot_out_dtype, allow_tf32):
|
||||
def _call(a, b, dot_out_dtype, allow_tf32, fp8_fast_accum):
|
||||
device = a.device
|
||||
# handle non-contiguous inputs if necessary
|
||||
if a.stride(0) > 1 and a.stride(1) > 1:
|
||||
@@ -155,6 +159,8 @@ class _matmul(torch.autograd.Function):
|
||||
if a.dtype in [tl.float8e4nv, tl.float8e4b15, tl.float8e5] or\
|
||||
b.dtype in [tl.float8e4nv, tl.float8e4b15, tl.float8e5]:
|
||||
c_dtype = torch.float16
|
||||
elif a.dtype in [torch.int8] or b.dtype in [torch.int8]:
|
||||
c_dtype = torch.int32
|
||||
else:
|
||||
c_dtype = get_higher_dtype(a.dtype, b.dtype)
|
||||
c = torch.empty((M, N), device=device, dtype=c_dtype)
|
||||
@@ -174,6 +180,8 @@ class _matmul(torch.autograd.Function):
|
||||
ab_dtype = True
|
||||
if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [tl.float8e4nv, tl.float8e5]:
|
||||
ab_dtype = False
|
||||
if a.dtype in [torch.int8] and b.dtype in [torch.int8]:
|
||||
ab_dtype = False
|
||||
# launch kernel
|
||||
grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||
_kernel[grid](a, b, c, M, N, K,
|
||||
@@ -182,12 +190,13 @@ class _matmul(torch.autograd.Function):
|
||||
c.stride(0), c.stride(1),
|
||||
dot_out_dtype=dot_out_dtype,
|
||||
allow_tf32=allow_tf32,
|
||||
fp8_fast_accum=fp8_fast_accum,
|
||||
GROUP_M=8, AB_DTYPE=ab_dtype)
|
||||
return c
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, a, b, dot_out_dtype=None, allow_tf32=True):
|
||||
return _matmul._call(a, b, dot_out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
|
||||
def forward(ctx, a, b, dot_out_dtype=None, allow_tf32=True, fp8_fast_accum=True):
|
||||
return _matmul._call(a, b, dot_out_dtype=dot_out_dtype, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum)
|
||||
|
||||
|
||||
matmul = _matmul.apply
|
||||
|
||||
@@ -5,14 +5,16 @@ import torch
|
||||
from .. import cdiv
|
||||
from .._C.libtriton.triton import runtime
|
||||
from ..runtime import driver
|
||||
from ..testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops
|
||||
from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops,
|
||||
nvsmi)
|
||||
|
||||
|
||||
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
''' return compute throughput in TOPS '''
|
||||
total_warps = num_ctas * min(num_warps, 4)
|
||||
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
|
||||
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device)
|
||||
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
|
||||
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, cur_sm_clock, backend, device)
|
||||
return tflops
|
||||
|
||||
|
||||
@@ -20,7 +22,8 @@ def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
''' return compute throughput in TOPS '''
|
||||
total_warps = num_ctas * min(num_warps, 4)
|
||||
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
|
||||
tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device)
|
||||
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
|
||||
tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, cur_sm_clock, backend, device)
|
||||
return tflops
|
||||
|
||||
|
||||
|
||||
@@ -132,9 +132,11 @@ class Autotuner(KernelInterface):
|
||||
full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs}
|
||||
if config.pre_hook is not None:
|
||||
config.pre_hook(full_nargs)
|
||||
return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization, **kwargs, **config.kwargs)
|
||||
ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization, **kwargs, **config.kwargs)
|
||||
self.nargs = None
|
||||
return ret
|
||||
|
||||
def prune_configs(self, kwargs):
|
||||
pruned_configs = self.configs
|
||||
|
||||
@@ -11,7 +11,10 @@ static inline void gpuAssert(CUresult code, const char *file, int line) {
|
||||
char err[1024] = {0};
|
||||
strcat(err, prefix);
|
||||
strcat(err, str);
|
||||
PyGILState_STATE gil_state;
|
||||
gil_state = PyGILState_Ensure();
|
||||
PyErr_SetString(PyExc_RuntimeError, err);
|
||||
PyGILState_Release(gil_state);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -327,7 +330,7 @@ static PyObject *memFree(PyObject *self, PyObject *args) {
|
||||
// Helper function to convert a Python list to a cuuint64_t array
|
||||
static cuuint64_t *list_to_cuuint64_array(PyObject *listObj) {
|
||||
Py_ssize_t len = PyList_Size(listObj);
|
||||
cuuint64_t *array = malloc(len * sizeof(cuuint64_t));
|
||||
cuuint64_t *array = (cuuint64_t *)malloc(len * sizeof(cuuint64_t));
|
||||
for (Py_ssize_t i = 0; i < len; i++) {
|
||||
PyObject *item = PyList_GetItem(listObj, i);
|
||||
array[i] = (cuuint64_t)PyLong_AsUnsignedLongLong(item);
|
||||
@@ -338,7 +341,7 @@ static cuuint64_t *list_to_cuuint64_array(PyObject *listObj) {
|
||||
// Helper function to convert a Python list to a cuuint32_t array
|
||||
static cuuint32_t *list_to_cuuint32_array(PyObject *listObj) {
|
||||
Py_ssize_t len = PyList_Size(listObj);
|
||||
cuuint32_t *array = malloc(len * sizeof(cuuint32_t));
|
||||
cuuint32_t *array = (cuuint32_t *)malloc(len * sizeof(cuuint32_t));
|
||||
for (Py_ssize_t i = 0; i < len; i++) {
|
||||
PyObject *item = PyList_GetItem(listObj, i);
|
||||
array[i] = (cuuint32_t)PyLong_AsUnsignedLong(item);
|
||||
|
||||
@@ -13,7 +13,10 @@ static inline void gpuAssert(hipError_t code, const char *file, int line) {
|
||||
const char *str = hipGetErrorString(code);
|
||||
char err[1024] = {0};
|
||||
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str);
|
||||
PyGILState_STATE gil_state;
|
||||
gil_state = PyGILState_Ensure();
|
||||
PyErr_SetString(PyExc_RuntimeError, err);
|
||||
PyGILState_Release(gil_state);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,14 @@ def default_cache_dir():
|
||||
return os.path.join(Path.home(), ".triton", "cache")
|
||||
|
||||
|
||||
def default_override_dir():
|
||||
return os.path.join(Path.home(), ".triton", "override")
|
||||
|
||||
|
||||
def default_dump_dir():
|
||||
return os.path.join(Path.home(), ".triton", "dump")
|
||||
|
||||
|
||||
class CacheManager(ABC):
|
||||
def __init__(self, key):
|
||||
pass
|
||||
@@ -36,17 +44,26 @@ class CacheManager(ABC):
|
||||
|
||||
|
||||
class FileCacheManager(CacheManager):
|
||||
def __init__(self, key):
|
||||
def __init__(self, key, override=False, dump=False):
|
||||
self.key = key
|
||||
self.lock_path = None
|
||||
# create cache directory if it doesn't exist
|
||||
self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir()
|
||||
if self.cache_dir:
|
||||
if (dump):
|
||||
self.cache_dir = default_dump_dir()
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||
os.makedirs(self.cache_dir, exist_ok=True)
|
||||
elif (override):
|
||||
self.cache_dir = default_override_dir()
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
else:
|
||||
raise RuntimeError("Could not create or locate cache dir")
|
||||
# create cache directory if it doesn't exist
|
||||
self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir()
|
||||
if self.cache_dir:
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||
os.makedirs(self.cache_dir, exist_ok=True)
|
||||
else:
|
||||
raise RuntimeError("Could not create or locate cache dir")
|
||||
|
||||
def _make_path(self, filename) -> str:
|
||||
return os.path.join(self.cache_dir, filename)
|
||||
@@ -131,3 +148,11 @@ def get_cache_manager(key) -> CacheManager:
|
||||
__cache_cls_nme = user_cache_manager
|
||||
|
||||
return __cache_cls(key)
|
||||
|
||||
|
||||
def get_override_manager(key) -> CacheManager:
|
||||
return __cache_cls(key, override=True)
|
||||
|
||||
|
||||
def get_dump_manager(key) -> CacheManager:
|
||||
return __cache_cls(key, dump=True)
|
||||
|
||||
527
python/triton/runtime/interpreter.py
Normal file
527
python/triton/runtime/interpreter.py
Normal file
@@ -0,0 +1,527 @@
|
||||
import inspect
|
||||
|
||||
import numpy as np
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from .._C.libtriton.triton import interpreter as _interpreter
|
||||
|
||||
|
||||
# TODO: duplicate
|
||||
def str_to_ty(name):
|
||||
language = tl
|
||||
if name[0] == "*":
|
||||
ty = str_to_ty(name[1:])
|
||||
return language.pointer_type(ty)
|
||||
tys = {
|
||||
"fp8e4nv": language.float8e4nv,
|
||||
"fp8e5": language.float8e5,
|
||||
"fp8e4b15": language.float8e4b15,
|
||||
"fp8e4b15x4": language.float8e4b15x4,
|
||||
"fp16": language.float16,
|
||||
"bf16": language.bfloat16,
|
||||
"fp32": language.float32,
|
||||
"fp64": language.float64,
|
||||
"i1": language.int1,
|
||||
"i8": language.int8,
|
||||
"i16": language.int16,
|
||||
"i32": language.int32,
|
||||
"i64": language.int64,
|
||||
"u8": language.uint8,
|
||||
"u16": language.uint16,
|
||||
"u32": language.uint32,
|
||||
"u64": language.uint64,
|
||||
"B": language.int1,
|
||||
}
|
||||
return tys[name]
|
||||
|
||||
|
||||
class TensorHandle:
|
||||
|
||||
def __init__(self, data, dtype):
|
||||
self.data = data
|
||||
self.dtype = dtype
|
||||
|
||||
def __bool__(self):
|
||||
return bool(self.data.all())
|
||||
|
||||
|
||||
class BlockPointerHandle:
|
||||
|
||||
def __init__(self, base, shape, strides, offsets, tensor_shape, order):
|
||||
self.base = base
|
||||
self.shape = shape
|
||||
self.strides = strides
|
||||
self.offsets = offsets
|
||||
self.tensor_shape = tensor_shape
|
||||
self.order = order
|
||||
|
||||
def materialize_pointers(self, boundary_check):
|
||||
dtype_tt = self.base.dtype.element_ty
|
||||
n_bytes = dtype_tt.primitive_bitwidth // 8
|
||||
tensor_shape = self.tensor_shape
|
||||
ptrs = np.broadcast_to(self.base.data, self.tensor_shape)
|
||||
masks = np.ones(self.tensor_shape, dtype=bool)
|
||||
for dim in range(len(tensor_shape)):
|
||||
bcast_dims = [1] * len(tensor_shape)
|
||||
bcast_dims[dim] = tensor_shape[dim]
|
||||
off = (self.offsets[dim].data + np.arange(tensor_shape[dim])).reshape(bcast_dims)
|
||||
ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64)
|
||||
if dim in boundary_check:
|
||||
masks = np.logical_and(masks, off < self.shape[dim].data)
|
||||
ptrs = TensorHandle(ptrs, self.base.dtype)
|
||||
return ptrs, masks
|
||||
|
||||
|
||||
def wrap_ret(compute_ret_ty):
|
||||
def wrapper(fn):
|
||||
def wrapped(*args, **kwargs):
|
||||
ret = fn(*args, **kwargs)
|
||||
return TensorHandle(ret.data, compute_ret_ty(*args, **kwargs))
|
||||
return wrapped
|
||||
return wrapper
|
||||
|
||||
|
||||
class Builder:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.arch = None
|
||||
# pass
|
||||
|
||||
def set_grid_idx(self, x, y, z):
|
||||
assert x < self.grid_dim[0]
|
||||
assert y < self.grid_dim[1]
|
||||
assert z < self.grid_dim[2]
|
||||
self.grid_idx = (x, y, z)
|
||||
|
||||
def set_grid_dim(self, nx, ny, nz):
|
||||
self.grid_dim = (nx, ny, nz)
|
||||
|
||||
def np_dtype(self, tt_dtype):
|
||||
if isinstance(tt_dtype, tl.pointer_type):
|
||||
return np.dtype(np.uint64)
|
||||
np_types = {
|
||||
tl.float16: np.dtype(np.float16),
|
||||
tl.float32: np.dtype(np.float32),
|
||||
tl.float64: np.dtype(np.float64),
|
||||
tl.int8: np.dtype(np.int8),
|
||||
tl.uint8: np.dtype(np.uint8),
|
||||
tl.int16: np.dtype(np.int16),
|
||||
tl.uint16: np.dtype(np.uint16),
|
||||
tl.int32: np.dtype(np.int32),
|
||||
tl.uint32: np.dtype(np.uint32),
|
||||
tl.int64: np.dtype(np.int64),
|
||||
tl.uint64: np.dtype(np.uint64),
|
||||
}
|
||||
return np_types[tt_dtype]
|
||||
|
||||
# constants
|
||||
def get_half_ty(self):
|
||||
return tl.float16
|
||||
|
||||
def get_float_ty(self):
|
||||
return tl.float32
|
||||
|
||||
def get_int64_ty(self):
|
||||
return tl.int64
|
||||
|
||||
def get_ptr_ty(self, elt_ty, addr_space):
|
||||
return tl.pointer_type(elt_ty, addr_space)
|
||||
|
||||
def get_block_ty(self, dtype, shape):
|
||||
return tl.tensor(shape, dtype)
|
||||
|
||||
def get_int32(self, value):
|
||||
return TensorHandle(np.array([value], dtype=np.int32), tl.int32)
|
||||
|
||||
def get_int64(self, value):
|
||||
return TensorHandle(np.array([value], dtype=np.int64), tl.int64)
|
||||
|
||||
def get_fp16(self, value):
|
||||
return TensorHandle(np.array([value], dtype=np.float16), tl.float16)
|
||||
|
||||
def get_fp32(self, value):
|
||||
return TensorHandle(np.array([value], dtype=np.float32), tl.float32)
|
||||
|
||||
def get_null_value(self, type):
|
||||
return TensorHandle(np.array([0], dtype=self.np_dtype(type)), type)
|
||||
|
||||
# programming model
|
||||
def create_get_program_id(self, axis):
|
||||
assert self.grid_idx is not None
|
||||
return TensorHandle(np.array([self.grid_idx[axis]], dtype=np.int32), tl.int32)
|
||||
|
||||
def create_get_num_programs(self, axis):
|
||||
return TensorHandle(np.array([self.grid_dim[axis]], dtype=np.int32), tl.int32)
|
||||
|
||||
# memory ops
|
||||
def create_load(self, ptr, _0, _1, is_volatile):
|
||||
mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1)
|
||||
other = None
|
||||
return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile)
|
||||
|
||||
def create_store(self, ptr, val, _0, _1):
|
||||
mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1)
|
||||
return self.create_masked_store(ptr, val, mask, None, None)
|
||||
|
||||
def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile):
|
||||
dtype_tt = ptrs.dtype.element_ty
|
||||
dtype_np = self.np_dtype(dtype_tt)
|
||||
if other is None:
|
||||
other = TensorHandle(np.ones_like(ptrs.data, dtype=dtype_np), dtype_tt)
|
||||
ret = _interpreter.load(ptrs.data, mask.data, other.data, dtype_np)
|
||||
return TensorHandle(ret, dtype_tt)
|
||||
|
||||
def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy):
|
||||
return _interpreter.store(ptrs.data, value.data, mask.data)
|
||||
|
||||
# casting ops
|
||||
def cast_impl(self, src, dst_type):
|
||||
if isinstance(dst_type, tl.tensor):
|
||||
dst_type = dst_type.dtype
|
||||
return TensorHandle(src.data.astype(self.np_dtype(dst_type)), dst_type)
|
||||
|
||||
create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type)
|
||||
create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type)
|
||||
create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type)
|
||||
create_fp_to_ui = lambda self, src, dst_type: self.cast_impl(src, dst_type)
|
||||
create_fp_ext = lambda self, src, dst_type: self.cast_impl(src, dst_type)
|
||||
create_fp_trunc = lambda self, src, dst_type: self.cast_impl(src, dst_type)
|
||||
create_int_cast = lambda self, src, dst_type, is_signed: self.cast_impl(src, dst_type)
|
||||
|
||||
def create_fp_to_fp(self, src, dst_type):
|
||||
assert "float8 not NotImplemented yet"
|
||||
|
||||
def create_bitcast(self, src, dst_type):
|
||||
return TensorHandle(src.data.view(self.np_dtype(dst_type)), dst_type)
|
||||
|
||||
# binary operators
|
||||
def binary_op(self, lhs, rhs, op):
|
||||
return TensorHandle(op(lhs.data, rhs.data), lhs.dtype)
|
||||
|
||||
create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add)
|
||||
create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
|
||||
create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide)
|
||||
create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder)
|
||||
create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract)
|
||||
create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
|
||||
create_sdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.floor_divide)
|
||||
create_udiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.floor_divide)
|
||||
create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder)
|
||||
create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder)
|
||||
create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add)
|
||||
create_sub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract)
|
||||
create_shl = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.left_shift)
|
||||
create_lshr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift)
|
||||
create_ashr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift)
|
||||
create_minsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
|
||||
create_minui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
|
||||
create_minf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
|
||||
create_maxsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
|
||||
create_maxui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
|
||||
create_maxf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
|
||||
create_icmpSLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
|
||||
create_icmpSLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
|
||||
create_icmpSGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
|
||||
create_icmpSGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
|
||||
create_icmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
|
||||
create_icmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
|
||||
create_icmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
|
||||
create_icmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
|
||||
create_icmpEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
|
||||
create_icmpNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
|
||||
create_fcmpOLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
|
||||
create_fcmpOGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
|
||||
create_fcmpOLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
|
||||
create_fcmpOGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
|
||||
create_fcmpOEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
|
||||
create_fcmpONE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
|
||||
create_fcmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
|
||||
create_fcmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
|
||||
create_fcmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
|
||||
create_fcmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
|
||||
create_fcmpUEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
|
||||
create_fcmpUNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
|
||||
create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and)
|
||||
create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor)
|
||||
create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or)
|
||||
|
||||
# ternary functions
|
||||
def ternary_op(self, lhs, rhs, other, op):
|
||||
return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype)
|
||||
create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where)
|
||||
|
||||
# unary functions
|
||||
def unary_op(self, arg, op):
|
||||
return TensorHandle(op(arg.data), arg.dtype)
|
||||
create_exp = lambda self, arg: self.unary_op(arg, np.exp)
|
||||
create_cos = lambda self, arg: self.unary_op(arg, np.cos)
|
||||
create_sin = lambda self, arg: self.unary_op(arg, np.sin)
|
||||
create_log = lambda self, arg: self.unary_op(arg, np.log)
|
||||
create_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt)
|
||||
create_fabs = lambda self, arg: self.unary_op(arg, np.abs)
|
||||
create_iabs = lambda self, arg: self.unary_op(arg, np.abs)
|
||||
|
||||
# tensor operators
|
||||
create_dot = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.dot)
|
||||
create_view = lambda self, arg, shape: TensorHandle(arg.data.reshape(shape), arg.dtype)
|
||||
create_trans = lambda self, arg: self.unary_op(arg, np.transpose)
|
||||
|
||||
def create_dot(self, a, b, d, allow_tf32, maxNumImpreciseAcc):
|
||||
return TensorHandle(np.dot(a.data, b.data) + d.data, a.dtype)
|
||||
|
||||
def create_make_range(self, start, stop):
|
||||
return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32)
|
||||
|
||||
# pointer arithmetic
|
||||
|
||||
def create_addptr(self, ptr, offset):
|
||||
dtype_tt = ptr.dtype.element_ty
|
||||
return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype)
|
||||
|
||||
def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile):
|
||||
ptrs, masks = ptr.materialize_pointers(boundary_check)
|
||||
assert padding_option is None
|
||||
other = None
|
||||
return self.create_masked_load(ptrs, masks, other, cache_modifier, eviction_policy, is_volatile)
|
||||
|
||||
def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy):
|
||||
ptrs, masks = ptr.materialize_pointers(boundary_check)
|
||||
return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy)
|
||||
|
||||
def create_expand_dims(self, arg, axis):
|
||||
return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype)
|
||||
|
||||
def create_broadcast(self, arg, shape):
|
||||
return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype)
|
||||
|
||||
def create_int_to_ptr(self, val, dst_ty):
|
||||
return TensorHandle(val.data.astype(np.uint64), dst_ty)
|
||||
# def create_cat(self, lhs, rhs):
|
||||
# pass
|
||||
|
||||
# def create_broadcast(self, arg, shape):
|
||||
# pass
|
||||
|
||||
def create_splat(self, arg, shape):
|
||||
return TensorHandle(np.full(shape, arg.data[0], dtype=self.np_dtype(arg.dtype)), arg.dtype)
|
||||
|
||||
# def create_atomic_cas(self, ptr, cmp, val, sem):
|
||||
# pass
|
||||
|
||||
# def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem):
|
||||
# pass
|
||||
|
||||
# def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure):
|
||||
# pass
|
||||
|
||||
# def create_reduce(self, operands, axis):
|
||||
# pass
|
||||
|
||||
# def create_reduce_ret(self, args):
|
||||
# pass
|
||||
|
||||
# def create_scan(self, operands, axis):
|
||||
# pass
|
||||
|
||||
# def create_scan_ret(self, args):
|
||||
# pass
|
||||
|
||||
# def create_ptr_to_int(self, val, type):
|
||||
# pass
|
||||
|
||||
# def create_int_to_ptr(self, val, type):
|
||||
# pass
|
||||
|
||||
# def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack):
|
||||
# pass
|
||||
|
||||
# def create_print(self, prefix, values):
|
||||
# pass
|
||||
|
||||
# def create_assert(self, condition, message, fileName, funcName, lineNo):
|
||||
# pass
|
||||
|
||||
# def create_undef(self, type):
|
||||
# pass
|
||||
|
||||
# def create_barrier(self):
|
||||
# pass
|
||||
|
||||
def create_make_block_ptr(self, base, shape, strides, offsets, tensor_shape, order):
|
||||
return BlockPointerHandle(base, shape, strides, np.array(offsets), tensor_shape, order)
|
||||
|
||||
def create_advance(self, ptr, offsets):
|
||||
assert len(ptr.offsets) == len(offsets)
|
||||
ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, ptr.offsets, ptr.tensor_shape, ptr.order)
|
||||
for i in range(len(offsets)):
|
||||
ret.offsets[i].data += offsets[i].data
|
||||
return ret
|
||||
|
||||
|
||||
def patch_attr(obj, name, member, builder):
|
||||
new_member = lambda *args, member=member, **kwargs: (member(*args, **{k: v for k, v in kwargs.items() if k != '_builder'}, _builder=builder))
|
||||
setattr(obj, name, new_member)
|
||||
|
||||
|
||||
def _patch_lang_tensor(tensor, builder):
|
||||
for name, member in inspect.getmembers(tensor):
|
||||
if tl.core.is_builtin(member):
|
||||
patch_attr(tensor, name, member, builder)
|
||||
tensor.__index__ = lambda self: int(self.handle.data)
|
||||
tensor.__bool__ = lambda self: True
|
||||
tensor.__str__ = lambda self: str(self.handle.data)
|
||||
tensor.__getitem__ = lambda self, slices: self.handle.data.__getitem__(slices)
|
||||
|
||||
|
||||
def _patch_lang_core(lang, builder):
|
||||
for name, member in inspect.getmembers(lang):
|
||||
if tl.core.is_builtin(member):
|
||||
patch_attr(lang, name, member, builder)
|
||||
# reduce is better off with a separate patch due to how
|
||||
# the builder currently interfaces with custom functions
|
||||
|
||||
def _new_reduce(input, axis, combine_fn):
|
||||
fn = combine_fn.fn.__name__
|
||||
mapping = {
|
||||
'maximum': np.max,
|
||||
'_sum_combine': np.sum,
|
||||
}
|
||||
ret = mapping[fn](input.handle.data, axis=axis)
|
||||
ret_type = tl.block_type(input.dtype, ret.shape)
|
||||
return tl.core.tensor(TensorHandle(ret, input.dtype), ret_type)
|
||||
|
||||
lang.reduce = _new_reduce
|
||||
|
||||
|
||||
def _patch_lang_math(lang, builder):
|
||||
math = lang.math
|
||||
mapping = {
|
||||
'abs': 'abs',
|
||||
'acos': 'arccos',
|
||||
'asin': 'arcsin',
|
||||
'exp2': 'exp2',
|
||||
'log2': 'log2',
|
||||
'max': 'maximum',
|
||||
}
|
||||
|
||||
def make_numpy(name):
|
||||
def impl(*args, **kwargs):
|
||||
ret_type = args[0].type # TODO: incorrect
|
||||
ret_dtype = args[0].dtype # TODO: incorrect
|
||||
args = [arg.handle.data for arg in args]
|
||||
kwargs = {k: v.handle.data for k, v in kwargs.items()}
|
||||
ret = getattr(np, mapping[name])(*args, **kwargs)
|
||||
ret = tl.core.tensor(TensorHandle(ret, ret_dtype), ret_type)
|
||||
return ret
|
||||
return impl
|
||||
|
||||
def make_fallback(name):
|
||||
def fallback(*args, **kwargs):
|
||||
raise NotImplementedError(f"""
|
||||
{name} not supported in interpreter mode: no known numpy implementation.
|
||||
If you think that {name} in fact does have a numpy implementation, please add it
|
||||
to the mapping in python/triton/interpreter/new_interpreter.py:_patch_lang_math.
|
||||
""")
|
||||
return fallback
|
||||
|
||||
for name, member in inspect.getmembers(math):
|
||||
if name in mapping:
|
||||
setattr(math, name, make_numpy(name))
|
||||
else:
|
||||
setattr(math, name, make_fallback(name))
|
||||
|
||||
|
||||
# TODO: wrap everything in triton tensors
|
||||
def _implicit_cvt(arg):
|
||||
if isinstance(arg, int):
|
||||
ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
|
||||
handle = TensorHandle(np.array([arg], dtype=np.int32), ty)
|
||||
return tl.tensor(handle, ty)
|
||||
if hasattr(arg, 'data_ptr'):
|
||||
ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
|
||||
handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
|
||||
return tl.tensor(handle, ty)
|
||||
return arg
|
||||
|
||||
|
||||
def _unwrap(tensor):
|
||||
if isinstance(tensor, triton.TensorWrapper):
|
||||
return tensor.base
|
||||
return tensor
|
||||
|
||||
|
||||
builder = Builder()
|
||||
|
||||
RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization']
|
||||
|
||||
|
||||
class GridExecutor:
|
||||
|
||||
def __init__(self, fn, arg_names, grid):
|
||||
from .jit import _normalize_ty # TODO: modularize
|
||||
self.fn = fn
|
||||
self.arg_names = arg_names
|
||||
self.grid = grid
|
||||
__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
|
||||
self.constexprs = [name for name in arg_names if __annotations__.get(name) == 'constexpr']
|
||||
|
||||
def _patch_lang(self, builder):
|
||||
lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
|
||||
assert len(lang) == 1, "triton.language must be visible from within jit'd function"
|
||||
_patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
|
||||
_patch_lang_core(lang[0], builder)
|
||||
_patch_lang_math(lang[0], builder)
|
||||
|
||||
def __call__(self, *args_dev, **kwargs):
|
||||
args_hst = [_unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev]
|
||||
# removes reserved keywords from kwargs
|
||||
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}
|
||||
# remaps core language functions to interpreted ones
|
||||
self._patch_lang(builder)
|
||||
# we need to copy arguments to the host for the interpreter
|
||||
# implicitly convert tensor arguments to their base pointers
|
||||
args = inspect.getcallargs(self.fn, *args_hst, **kwargs)
|
||||
args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()}
|
||||
# iterate through grid
|
||||
grid = self.grid(args) if callable(self.grid) else self.grid
|
||||
assert len(grid) <= 3
|
||||
grid = grid + (1,) * (3 - len(grid))
|
||||
builder.set_grid_dim(*grid)
|
||||
for x in range(grid[0]):
|
||||
for y in range(grid[1]):
|
||||
for z in range(grid[2]):
|
||||
builder.set_grid_idx(x, y, z)
|
||||
self.fn(**args)
|
||||
# copy arguments back to propagate side-effects
|
||||
for arg_dev, arg_hst in zip(args_dev, args_hst):
|
||||
if hasattr(arg_dev, 'data_ptr'):
|
||||
_unwrap(arg_dev).copy_(arg_hst.to(arg_dev.device))
|
||||
|
||||
|
||||
class InterpretedFunction:
|
||||
|
||||
def _patch_lang(self, builder):
|
||||
lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
|
||||
assert len(lang) == 1, "triton.language must be visible from within jit'd function"
|
||||
_patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
|
||||
_patch_lang_core(lang[0], builder)
|
||||
|
||||
def __init__(self, fn) -> None:
|
||||
self.fn = fn
|
||||
|
||||
def run(*args, **kwargs):
|
||||
grid = kwargs['grid']
|
||||
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ['grid']}
|
||||
|
||||
return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs)
|
||||
self.run = run
|
||||
signature = inspect.signature(fn)
|
||||
self.arg_names = [v.name for v in signature.parameters.values()]
|
||||
|
||||
def __getitem__(self, grid):
|
||||
return GridExecutor(self.fn, self.arg_names, grid)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
self._patch_lang(builder)
|
||||
return self.fn(*args, **kwargs)
|
||||
@@ -14,6 +14,7 @@ from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union,
|
||||
from .._C.libtriton.triton import TMAInfos
|
||||
from ..common.backend import get_backend, path_to_ptxas
|
||||
from ..language.core import dtype
|
||||
from .interpreter import InterpretedFunction
|
||||
|
||||
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
TRITON_VERSION = "2.1.0"
|
||||
@@ -256,6 +257,8 @@ class JITFunction(KernelInterface[T]):
|
||||
"float8_e5m2fnuz": "fp8e5b16",
|
||||
"float8e4b15": "fp8e4b15",
|
||||
"float8e4b15x4": "fp8e4b15x4",
|
||||
"float8_e4m3fn": "fp8e4nv",
|
||||
"float8_e5m2": "fp8e5",
|
||||
"float16": "fp16",
|
||||
"bfloat16": "bf16",
|
||||
"float32": "fp32",
|
||||
@@ -274,10 +277,6 @@ class JITFunction(KernelInterface[T]):
|
||||
tys[v] = v
|
||||
return key if isinstance(key, str) else f"*{tys[dtype_str]}"
|
||||
|
||||
def _make_signature(self, sig_key):
|
||||
signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)])
|
||||
return signature
|
||||
|
||||
def _make_constants(self, constexpr_key):
|
||||
constants = dict(zip(self.constexprs, constexpr_key))
|
||||
return constants
|
||||
@@ -304,29 +303,29 @@ class JITFunction(KernelInterface[T]):
|
||||
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={
|
||||
"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
|
||||
|
||||
def _get_arg_specialization_key(self, arg) -> str:
|
||||
arg_annotation = self.__annotations__.get(arg, '')
|
||||
def _get_arg_specialization_key(self, arg_name, arg):
|
||||
arg_annotation = self.__annotations__.get(arg_name, '')
|
||||
if arg_annotation == '':
|
||||
return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") \
|
||||
else ({arg} % {JITFunction.divisibility} == 0, {arg} % {JITFunction.divisibility_8} == 0, {arg} == 1) if isinstance({arg}, int) \
|
||||
else (False,)'
|
||||
return (arg.data_ptr() % JITFunction.divisibility == 0) if hasattr(arg, "data_ptr") \
|
||||
else (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1) if isinstance(arg, int) \
|
||||
else (False,)
|
||||
elif 'Tensor' in arg_annotation:
|
||||
return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0)'
|
||||
elif arg_annotation == 'int':
|
||||
return f'({arg} % {JITFunction.divisibility} == 0, {arg} % {JITFunction.divisibility_8} == 0, {arg} == 1)'
|
||||
return (arg.data_ptr() % JITFunction.divisibility == 0)
|
||||
elif 'int' in arg_annotation or 'bool' in arg_annotation:
|
||||
return (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1)
|
||||
else:
|
||||
return '(False,)'
|
||||
return (False,)
|
||||
|
||||
def _get_arg_sig_key(self, arg) -> str:
|
||||
arg_annotation = self.__annotations__.get(arg, '')
|
||||
def _get_arg_sig_key(self, arg_name, arg) -> str:
|
||||
arg_annotation = self.__annotations__.get(arg_name, '')
|
||||
if 'Tensor' in arg_annotation:
|
||||
return f'{arg}.dtype'
|
||||
return arg.dtype
|
||||
elif arg_annotation == 'bool':
|
||||
return "i1"
|
||||
elif arg_annotation == 'float':
|
||||
return 'fp32'
|
||||
else:
|
||||
return f'_key_of({arg})'
|
||||
return self._key_of(arg)
|
||||
|
||||
def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: List[bool]) -> str:
|
||||
device_types = [device_type for device_type in device_types if device_type != '']
|
||||
@@ -344,32 +343,110 @@ class JITFunction(KernelInterface[T]):
|
||||
return device_types[0] if len(device_types) > 0 else 'cuda'
|
||||
|
||||
def _make_launcher(self):
|
||||
regular_args = [f'{arg}' for i, arg in enumerate(
|
||||
regular_args = [arg for i, arg in enumerate(
|
||||
self.arg_names) if i not in self.constexprs]
|
||||
constexpr_args = [
|
||||
f'{arg}' for i, arg in enumerate(
|
||||
self.arg_names) if i in self.constexprs]
|
||||
args = ', '.join(regular_args)
|
||||
# cache key for regular argument type
|
||||
sig_keys = ', '.join([self._get_arg_sig_key(arg) for arg in regular_args])
|
||||
device_types = '[' + ', '.join([f'_device_of({arg})' for arg in regular_args]) + ']'
|
||||
pinned_memory_flags = '[' + ', '.join([f'_pinned_memory_of({arg})' for arg in regular_args]) + ']'
|
||||
# cache key for constexpr argument values
|
||||
constexpr_keys = ', '.join(constexpr_args)
|
||||
# cache key for argument specialization
|
||||
specializations = []
|
||||
for i, arg in enumerate(regular_args):
|
||||
if i in self.do_not_specialize:
|
||||
continue
|
||||
specializations += [self._get_arg_specialization_key(arg)]
|
||||
constexpr_args = [arg for i, arg in enumerate(
|
||||
self.arg_names) if i in self.constexprs]
|
||||
|
||||
spec_keys = ', '.join(specializations)
|
||||
grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
|
||||
def regular_args_v(args_proxy):
|
||||
return [args_proxy[arg_name] for arg_name in regular_args]
|
||||
|
||||
def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, stream, warmup, device, device_type):
|
||||
from ..compiler import (CompiledKernel, compile,
|
||||
get_arch_default_num_stages,
|
||||
get_arch_default_num_warps)
|
||||
sig_key = tuple([self._get_arg_sig_key(arg_name, args_proxy[arg_name]) for arg_name in regular_args])
|
||||
constexpr_key = tuple([args_proxy[arg_name] for arg_name in constexpr_args])
|
||||
specializations = []
|
||||
for i, arg_name in enumerate(regular_args):
|
||||
if i in self.do_not_specialize:
|
||||
continue
|
||||
specializations += [self._get_arg_specialization_key(arg_name, args_proxy[arg_name])]
|
||||
|
||||
spec_key = tuple(specializations)
|
||||
assert num_ctas > 0
|
||||
assert grid is not None
|
||||
if callable(grid):
|
||||
grid = grid(args_proxy)
|
||||
grid_size = len(grid)
|
||||
grid_0 = grid[0]
|
||||
grid_1 = grid[1] if grid_size > 1 else 1
|
||||
grid_2 = grid[2] if grid_size > 2 else 1
|
||||
if device_type is None:
|
||||
device_types = [self._device_of(arg) for arg in regular_args_v(args_proxy)]
|
||||
device_types = [_device_type for _device_type in device_types if _device_type != '']
|
||||
device_type = self._conclude_device_type(device_types, [self._pinned_memory_of(arg) for arg in
|
||||
regular_args_v(args_proxy)])
|
||||
|
||||
device_backend = None
|
||||
if device_type not in ['cuda']:
|
||||
device_backend = get_backend(device_type)
|
||||
if device_backend is None:
|
||||
raise ValueError('Cannot find backend for ' + device_type)
|
||||
|
||||
if device is None:
|
||||
if device_type in ['cuda']:
|
||||
device = get_current_device()
|
||||
set_current_device(device)
|
||||
else:
|
||||
device = device_backend.get_current_device()
|
||||
device_backend.set_current_device(device)
|
||||
if stream is None and not warmup:
|
||||
if device_type in ['cuda']:
|
||||
stream = get_cuda_stream(device)
|
||||
else:
|
||||
stream = device_backend.get_stream()
|
||||
|
||||
if num_warps is None:
|
||||
num_warps = get_arch_default_num_warps(device_type)
|
||||
if num_stages is None:
|
||||
num_stages = get_arch_default_num_stages(device_type)
|
||||
|
||||
key = (version_key(), sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, self.debug)
|
||||
if extern_libs is not None:
|
||||
key = (key, tuple(extern_libs.items()))
|
||||
|
||||
bin = self.cache[device].get(key, None)
|
||||
if bin is not None:
|
||||
# build dict of constant values
|
||||
args = regular_args_v(args_proxy)
|
||||
# Create tensormaps and append to args
|
||||
args = bin.assemble_tensormap_to_arg(args)
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
|
||||
return bin
|
||||
# kernel not cached -- compile
|
||||
else:
|
||||
# build dict of constant values
|
||||
args = regular_args_v(args_proxy)
|
||||
all_args = tuple([args_proxy[arg_name] for arg_name in self.arg_names])
|
||||
configs = self._get_config(*all_args),
|
||||
constants = self._make_constants(constexpr_key)
|
||||
constants.update({i: None for i, arg in enumerate(all_args) if arg is None})
|
||||
constants.update({i: 1 for i in configs[0].equal_to_1})
|
||||
# build kernel signature -- doesn't include specialized arguments
|
||||
signature = {i: self._type_of(self._key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs}
|
||||
# build stub signature -- includes arguments that are specialized
|
||||
for i, arg in constants.items():
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {i} is not supported")
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
|
||||
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
# Create tensormaps and append to args
|
||||
args = bin.assemble_tensormap_to_arg(args)
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
|
||||
self.cache[device][key] = bin
|
||||
return bin
|
||||
return None
|
||||
|
||||
# create a wrapper to call launcher_body
|
||||
args_map = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
|
||||
args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = triton.language.dtype(\'{dflt}\')' if dtype.is_dtype(f'{dflt}') else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
|
||||
args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
|
||||
|
||||
src = f"""
|
||||
import triton
|
||||
<<<<<<< HEAD
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, matrix_instr_nonkdim=0, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages
|
||||
sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()}
|
||||
@@ -449,19 +526,12 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
self.cache[device][key] = bin
|
||||
return bin
|
||||
return None
|
||||
=======
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, stream, warmup, device, device_type)
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
"""
|
||||
scope = {"version_key": version_key(),
|
||||
"get_cuda_stream": get_cuda_stream,
|
||||
"self": self,
|
||||
"_spec_of": self._spec_of,
|
||||
"_key_of": self._key_of,
|
||||
"_device_of": self._device_of,
|
||||
"_pinned_memory_of": self._pinned_memory_of,
|
||||
"cache": self.cache,
|
||||
"__spec__": __spec__,
|
||||
"get_backend": get_backend,
|
||||
"get_current_device": get_current_device,
|
||||
"set_current_device": set_current_device}
|
||||
scope = {"launcher_body": launcher_body}
|
||||
exec(src, scope)
|
||||
return scope[self.fn.__name__]
|
||||
|
||||
@@ -572,7 +642,6 @@ def jit(
|
||||
do_not_specialize: Optional[Iterable[int]] = None,
|
||||
debug: Optional[bool] = None,
|
||||
noinline: Optional[bool] = None,
|
||||
interpret: Optional[bool] = None,
|
||||
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
|
||||
"""
|
||||
Decorator for JIT-compiling a function using the Triton compiler.
|
||||
@@ -594,9 +663,8 @@ def jit(
|
||||
|
||||
def decorator(fn: T) -> JITFunction[T]:
|
||||
assert callable(fn)
|
||||
if interpret:
|
||||
from ..interpreter.interpreter import GridSelector
|
||||
return GridSelector(fn)
|
||||
if os.getenv("TRITON_INTERPRET", "0") == "1":
|
||||
return InterpretedFunction(fn)
|
||||
else:
|
||||
return JITFunction(
|
||||
fn,
|
||||
|
||||
@@ -32,8 +32,11 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None):
|
||||
"""
|
||||
if torch.cuda.current_stream() == torch.cuda.default_stream():
|
||||
raise RuntimeError("Cannot capture graph in default stream. Please use side stream in benchmark code.")
|
||||
# record CUDAGraph
|
||||
# warmup
|
||||
fn()
|
||||
# step 1 - we estimate the amount of time the kernel call takes
|
||||
# NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point
|
||||
# but it is probably good enough
|
||||
if grad_to_none is not None:
|
||||
for x in grad_to_none:
|
||||
x.detach_()
|
||||
@@ -43,39 +46,35 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None):
|
||||
with torch.cuda.graph(g):
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
fn = lambda: g.replay()
|
||||
# Estimate the runtime of the function
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event.record()
|
||||
fn()
|
||||
g.replay()
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
estimate_ms = start_event.elapsed_time(end_event)
|
||||
# compute number of repetition to last `rep` ms
|
||||
n_repeat = max(1, int(rep / estimate_ms))
|
||||
# compute number of repetition to last `rep` ms
|
||||
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
|
||||
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
|
||||
ret = []
|
||||
n_retries = 50
|
||||
for _ in range(n_retries):
|
||||
# Benchmark
|
||||
torch.cuda.synchronize()
|
||||
# step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize
|
||||
# host overhead
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
for i in range(n_repeat):
|
||||
# we don't want `fn` to accumulate gradient values
|
||||
# if it contains a backward pass. So we clear the
|
||||
# provided gradients
|
||||
if grad_to_none is not None:
|
||||
for x in grad_to_none:
|
||||
x.grad = None
|
||||
# record time of `fn`
|
||||
start_event[i].record()
|
||||
fn()
|
||||
end_event[i].record()
|
||||
torch.cuda.synchronize()
|
||||
# measure time and return
|
||||
ret = []
|
||||
n_retries = 10
|
||||
for i in range(n_retries):
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event.record()
|
||||
g.replay()
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
|
||||
ret.append(torch.min(times))
|
||||
ret += [start_event.elapsed_time(end_event) / n_repeat]
|
||||
return torch.mean(torch.tensor(ret)).item()
|
||||
|
||||
|
||||
@@ -266,7 +265,7 @@ class Mark:
|
||||
self.fn = fn
|
||||
self.benchmarks = benchmarks
|
||||
|
||||
def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool):
|
||||
def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, **kwrags):
|
||||
import os
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -287,7 +286,7 @@ class Mark:
|
||||
|
||||
row_mean, row_min, row_max = [], [], []
|
||||
for y in bench.line_vals:
|
||||
ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args)
|
||||
ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags)
|
||||
try:
|
||||
y_mean, y_min, y_max = ret
|
||||
except TypeError:
|
||||
@@ -328,14 +327,14 @@ class Mark:
|
||||
if save_path:
|
||||
df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)
|
||||
|
||||
def run(self, show_plots=False, print_data=False, save_path=''):
|
||||
def run(self, show_plots=False, print_data=False, save_path='', **kwargs):
|
||||
has_single_bench = isinstance(self.benchmarks, Benchmark)
|
||||
benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
|
||||
if save_path:
|
||||
html = open(os.path.join(save_path, "results.html"), "w")
|
||||
html.write("<html><body>\n")
|
||||
for bench in benchmarks:
|
||||
self._run(bench, save_path, show_plots, print_data)
|
||||
self._run(bench, save_path, show_plots, print_data, **kwargs)
|
||||
if save_path:
|
||||
html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
|
||||
if save_path:
|
||||
@@ -368,7 +367,7 @@ def get_dram_gbps(backend=None, device=None):
|
||||
return bw_gbps
|
||||
|
||||
|
||||
def get_max_tensorcore_tflops(dtype, backend=None, device=None, clock_rate=None):
|
||||
def get_max_tensorcore_tflops(dtype, clock_rate, backend=None, device=None):
|
||||
import torch
|
||||
|
||||
from .runtime import driver
|
||||
@@ -378,8 +377,6 @@ def get_max_tensorcore_tflops(dtype, backend=None, device=None, clock_rate=None)
|
||||
device = torch.cuda.current_device()
|
||||
|
||||
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4
|
||||
if not clock_rate:
|
||||
clock_rate = driver.utils.get_device_properties(device)["sm_clock_rate"] # in kHz
|
||||
capability = torch.cuda.get_device_capability(device)
|
||||
if capability[0] < 8:
|
||||
assert dtype == torch.float16
|
||||
@@ -423,21 +420,6 @@ def cuda_memcheck(**target_kwargs):
|
||||
return decorator
|
||||
|
||||
|
||||
def nvsmi_attr(attrs):
|
||||
attrs = ",".join(attrs)
|
||||
cmd = [
|
||||
"nvidia-smi",
|
||||
"-i",
|
||||
"0",
|
||||
"--query-gpu=" + attrs,
|
||||
"--format=csv,noheader,nounits",
|
||||
]
|
||||
out = subprocess.check_output(cmd)
|
||||
ret = out.decode(sys.stdout.encoding).split(",")
|
||||
ret = [int(x) for x in ret]
|
||||
return ret
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
|
||||
try:
|
||||
@@ -458,8 +440,8 @@ def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
|
||||
f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
|
||||
]
|
||||
)
|
||||
cur_sm_clock = nvsmi_attr(["clocks.current.sm"])[0]
|
||||
cur_mem_clock = nvsmi_attr(["clocks.current.memory"])[0]
|
||||
cur_sm_clock = nvsmi(["clocks.current.sm"])[0]
|
||||
cur_mem_clock = nvsmi(["clocks.current.memory"])[0]
|
||||
assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz"
|
||||
assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz"
|
||||
tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock
|
||||
@@ -471,7 +453,7 @@ def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
|
||||
subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"])
|
||||
|
||||
|
||||
def get_max_simd_tflops(dtype, backend=None, device=None):
|
||||
def get_max_simd_tflops(dtype, clock_rate, backend=None, device=None):
|
||||
import torch
|
||||
|
||||
from .runtime import driver
|
||||
@@ -481,7 +463,6 @@ def get_max_simd_tflops(dtype, backend=None, device=None):
|
||||
device = torch.cuda.current_device()
|
||||
|
||||
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4
|
||||
clock_rate = driver.utils.get_device_properties(device)["sm_clock_rate"] # in kHz
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 8:
|
||||
if dtype == torch.float32:
|
||||
|
||||
@@ -20,7 +20,7 @@ data along with utilities to load, unload and launch the kernel.
|
||||
signature is provided as a list of (optionally divisibility-hinted) types
|
||||
or constexpr values, e.g.
|
||||
|
||||
`compile.py --kernel-name kernel --signature "*f32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py`
|
||||
`compile.py --kernel-name kernel --signature "*fp32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py`
|
||||
|
||||
will compile triton.JITFunction of name `kernel` inside the file `/path/to/kernel.py`.
|
||||
Said kernel will be specialized such that argument 0, 1 are assumed to be multiple of 16,
|
||||
@@ -51,7 +51,7 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
out_name = args.out_name if args.out_name else args.kernel_name
|
||||
out_path = args.out_path if args.out_path else out_name
|
||||
out_path = args.out_path if args.out_path else Path(out_name)
|
||||
|
||||
# execute python sources and extract functions wrapped in JITFunction
|
||||
arg_path = Path(args.path)
|
||||
|
||||
@@ -20,8 +20,13 @@
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import functools
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
from ..common.backend import path_to_cuobjdump, path_to_nvdisasm
|
||||
|
||||
FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*')
|
||||
SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*')
|
||||
@@ -60,11 +65,26 @@ def processSassLines(fline, sline, labels):
|
||||
return (f'{ctrl}', f'{asm}')
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def get_sass(cubin_asm, fun=None):
|
||||
fd, path = tempfile.mkstemp()
|
||||
try:
|
||||
with open(fd, 'wb') as cubin:
|
||||
cubin.write(cubin_asm)
|
||||
sass = extract(path, fun)
|
||||
finally:
|
||||
os.remove(path)
|
||||
return sass
|
||||
|
||||
|
||||
def extract(file_path, fun):
|
||||
cuobjdump, _ = path_to_cuobjdump()
|
||||
nvdisasm, _ = path_to_nvdisasm()
|
||||
os.environ["NVDISASM_PATH"] = nvdisasm
|
||||
if fun is None:
|
||||
sass_str = subprocess.check_output(["cuobjdump", "-sass", file_path])
|
||||
sass_str = subprocess.check_output([cuobjdump, "-sass", file_path])
|
||||
else:
|
||||
sass_str = subprocess.check_output(["cuobjdump", "-fun", fun, "-sass", file_path])
|
||||
sass_str = subprocess.check_output([cuobjdump, "-fun", fun, "-sass", file_path])
|
||||
sass_lines = sass_str.splitlines()
|
||||
line_idx = 0
|
||||
while line_idx < len(sass_lines):
|
||||
|
||||
Reference in New Issue
Block a user