Merge pull request #410 from ROCmSoftwarePlatform/ifu-231117

Ifu 231117
This commit is contained in:
jayfurmanek
2023-12-15 09:09:40 -06:00
committed by GitHub
197 changed files with 9874 additions and 8267 deletions

View File

@@ -45,12 +45,12 @@ __all__ = [
"tools",
]
# -------------------------------------
# misc. utilities that don't fit well
# into any specific module
# -------------------------------------
def cdiv(x: int, y: int):
return (x + y - 1) // y

View File

@@ -1,5 +1,5 @@
import functools
import hashlib
import importlib
import importlib.util
import os
@@ -10,8 +10,12 @@ from typing import Dict
from ..runtime.driver import DriverBase
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TRITON_VERSION = "2.1.0"
class BaseBackend:
def __init__(self, device_type: str) -> None:
self.device_type = device_type
@@ -104,7 +108,7 @@ def get_backend(device_type: str):
def _path_to_binary(binary: str):
base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
paths = [
os.environ.get("TRITON_PTXAS_PATH", ""),
os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
os.path.join(base_dir, "third_party", "cuda", "bin", binary)
]
@@ -132,3 +136,48 @@ def path_to_cuobjdump():
@functools.lru_cache()
def path_to_nvdisasm():
return _path_to_binary("nvdisasm")
@functools.lru_cache()
def compute_core_version_key():
import pkgutil
contents = []
# frontend
with open(__file__, "rb") as f:
contents += [hashlib.sha1(f.read()).hexdigest()]
# compiler
compiler_path = os.path.join(TRITON_PATH, 'compiler')
for lib in pkgutil.iter_modules([compiler_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.sha1(f.read()).hexdigest()]
# backend
libtriton_hash = hashlib.sha1()
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
while True:
chunk = f.read(1024**2)
if not chunk:
break
libtriton_hash.update(chunk)
contents.append(libtriton_hash.hexdigest())
# language
language_path = os.path.join(TRITON_PATH, 'language')
for lib in pkgutil.iter_modules([language_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.sha1(f.read()).hexdigest()]
return '-'.join(TRITON_VERSION) + '-'.join(contents)
_cached_cuda_version_key = None
def get_cuda_version_key():
global _cached_cuda_version_key
if _cached_cuda_version_key is None:
key = compute_core_version_key()
try:
ptxas = path_to_ptxas()[0]
ptxas_version = subprocess.check_output([ptxas, "--version"])
except RuntimeError:
ptxas_version = b"NO_PTXAS"
_cached_cuda_version_key = key + '-' + hashlib.sha1(ptxas_version).hexdigest()
return _cached_cuda_version_key

View File

@@ -92,9 +92,15 @@ def _build(name, src, srcdir):
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
if is_hip():
ret = subprocess.check_call([cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{hip_lib_dir}", "-lamdhip64", "-o", so])
ret = subprocess.check_call([
cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC",
f"-L{hip_lib_dir}", "-lamdhip64", "-o", so
])
else:
cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
cc_cmd = [
cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda",
"-o", so
]
cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
ret = subprocess.check_call(cc_cmd)

View File

@@ -1,5 +1,8 @@
from .compiler import (CompiledKernel, compile, get_arch_default_num_stages,
get_arch_default_num_warps, instance_descriptor)
from .compiler import (CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps,
instance_descriptor)
from .errors import CompilationError
__all__ = ["compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", "get_arch_default_num_stages"]
__all__ = [
"compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps",
"get_arch_default_num_stages"
]

View File

@@ -10,8 +10,7 @@ from .._C.libtriton.triton import ir
from ..language import constexpr, tensor
# ideally we wouldn't need any runtime component
from ..runtime import JITFunction
from .errors import (CompilationError, CompileTimeAssertionFailure,
UnsupportedLanguageConstruct)
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
def mangle_ty(ty):
@@ -68,7 +67,10 @@ def _check_fn_args(node, fn, args):
if fn.noinline:
for idx, arg in enumerate(args):
if not _is_constexpr(arg) and not _is_triton_scalar(arg):
raise UnsupportedLanguageConstruct(fn.src, node, f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}')
raise UnsupportedLanguageConstruct(
fn.src, node,
f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}'
)
def _get_fn_file_line(fn):
@@ -89,6 +91,7 @@ _condition_types = {bool, int, type(None)} # Python types accepted for conditio
class enter_sub_region:
def __init__(self, generator):
self.generator = generator
@@ -109,6 +112,7 @@ class enter_sub_region:
# Check if the given syntax node has an "early" return
class ContainsReturnChecker(ast.NodeVisitor):
def __init__(self, gscope):
self.gscope = gscope
@@ -199,9 +203,10 @@ class ContainsReturnChecker(ast.NodeVisitor):
class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, attributes, constants, function_name, target,
module=None, is_kernel=False, function_types: Optional[Dict] = None,
debug=False, noinline=False, file_name: Optional[str] = None, begin_line=0):
def __init__(self, context, prototype, gscope, attributes, constants, function_name, target, module=None,
is_kernel=False, function_types: Optional[Dict] = None, debug=False, noinline=False,
file_name: Optional[str] = None, begin_line=0):
self.context = context
self.builder = ir.builder(context)
self.file_name = file_name
@@ -237,8 +242,10 @@ class CodeGenerator(ast.NodeVisitor):
))
def _define_name_lookup(self):
def local_lookup(name: str, absent):
value = self.lscope.get(name, absent) # this needs to be re-fetched from `self` every time, because it gets switched occasionally
# this needs to be re-fetched from `self` every time, because it gets switched occasionally
value = self.lscope.get(name, absent)
if value is not absent and name not in self.local_defs:
self.global_uses[name] = value
return value
@@ -255,8 +262,7 @@ class CodeGenerator(ast.NodeVisitor):
return name_lookup
def set_value(self, name: str,
value: Union[tensor, constexpr]) -> None:
def set_value(self, name: str, value: Union[tensor, constexpr]) -> None:
''' This function:
called by visit_Assign() & visit_FunctionDef() to store left value (lvalue)
1. record local defined name (FIXME: should consider control flow)
@@ -338,7 +344,8 @@ class CodeGenerator(ast.NodeVisitor):
self.visit(init_node)
# initialize function
visibility = "public" if self.is_kernel else "private"
self.fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline)
self.fn = self.builder.get_or_insert_function(self.module, self.function_name,
self.prototype.to_ir(self.builder), visibility, self.noinline)
self.module.push_back(self.fn)
entry = self.fn.add_entry_block()
arg_values = []
@@ -469,12 +476,23 @@ class CodeGenerator(ast.NodeVisitor):
rhs = self.visit(node.right)
method_name = self._method_name_for_bin_op.get(type(node.op))
if method_name is None:
raise UnsupportedLanguageConstruct(None, node, "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__))
raise UnsupportedLanguageConstruct(
None, node, "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__))
return self._apply_binary_method(method_name, lhs, rhs)
_method_name_for_bin_op: Dict[Type[ast.operator], str] = {
ast.Add: '__add__', ast.Sub: '__sub__', ast.Mult: '__mul__', ast.Div: '__truediv__',
ast.FloorDiv: '__floordiv__', ast.Mod: '__mod__', ast.Pow: '__pow__',
ast.LShift: '__lshift__', ast.RShift: '__rshift__', ast.BitAnd: '__and__', ast.BitOr: '__or__', ast.BitXor: '__xor__',
ast.Add: '__add__',
ast.Sub: '__sub__',
ast.Mult: '__mul__',
ast.Div: '__truediv__',
ast.FloorDiv: '__floordiv__',
ast.Mod: '__mod__',
ast.Pow: '__pow__',
ast.LShift: '__lshift__',
ast.RShift: '__rshift__',
ast.BitAnd: '__and__',
ast.BitOr: '__or__',
ast.BitXor: '__xor__',
}
def visit_then_else_blocks(self, node, liveins, then_block, else_block):
@@ -508,7 +526,8 @@ class CodeGenerator(ast.NodeVisitor):
if name in then_defs or name in else_defs:
names.append(name)
ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type)
ir_ret_types.append(then_defs[name].handle.get_type() if name in then_defs else else_defs[name].handle.get_type())
ir_ret_types.append(then_defs[name].handle.get_type() if name in
then_defs else else_defs[name].handle.get_type())
# variable defined in then but not in else
if name in then_defs and name not in else_defs:
else_defs[name] = liveins[name]
@@ -602,8 +621,7 @@ class CodeGenerator(ast.NodeVisitor):
contains_return = ContainsReturnChecker(self.gscope).visit(node)
if self.scf_stack and contains_return:
raise UnsupportedLanguageConstruct(
None, node,
"Cannot have `return` statements inside `while` or `for` statements in triton "
None, node, "Cannot have `return` statements inside `while` or `for` statements in triton "
"(note that this also applies to `return` statements that are inside functions "
"transitively called from within `while`/`for` statements)")
elif self.scf_stack or not contains_return:
@@ -612,10 +630,13 @@ class CodeGenerator(ast.NodeVisitor):
self.visit_if_top_level(cond, node)
else:
cond = _unwrap_if_constexpr(cond)
if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks
# not isinstance - we insist the real thing, no subclasses and no ducks
if type(cond) not in _condition_types:
raise UnsupportedLanguageConstruct(
None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
', '.join(_.__name__ for _ in _condition_types), type(cond).__name__))
None, node,
"`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
', '.join(_.__name__ for _ in _condition_types),
type(cond).__name__))
if cond:
self.visit_compound_statement(node.body)
else:
@@ -624,15 +645,52 @@ class CodeGenerator(ast.NodeVisitor):
def visit_IfExp(self, node):
cond = self.visit(node.test)
if _is_triton_tensor(cond):
raise UnsupportedLanguageConstruct(
None, node,
"Triton does not support `if` expressions (ternary operators) with dynamic conditions, use `if` statements instead")
cond = cond.to(language.int1, _builder=self.builder)
# TODO: Deal w/ more complicated return types (e.g tuple)
with enter_sub_region(self):
ip, last_loc = self._get_insertion_point_and_loc()
then_block = self.builder.create_block()
self.builder.set_insertion_point_to_start(then_block)
then_val = language.core._to_tensor(self.visit(node.body), self.builder)
then_block = self.builder.get_insertion_block()
else_block = self.builder.create_block()
self.builder.set_insertion_point_to_start(else_block)
# do not need to reset lscope since
# ternary expressions cannot define new variables
else_val = language.core._to_tensor(self.visit(node.orelse), self.builder)
else_block = self.builder.get_insertion_block()
self._set_insertion_point_and_loc(ip, last_loc)
assert then_val.type == else_val.type, \
f'ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}'
ret_type = then_val.type
ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else []
if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True)
then_block.merge_block_before(if_op.get_then_block())
if ret_type_ir:
self.builder.set_insertion_point_to_end(if_op.get_then_block())
self.builder.create_yield_op([then_val.handle])
self.builder.set_insertion_point_to_end(if_op.get_then_block())
else_block.merge_block_before(if_op.get_else_block())
if ret_type_ir:
self.builder.set_insertion_point_to_end(if_op.get_else_block())
self.builder.create_yield_op([else_val.handle])
return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None
else:
cond = _unwrap_if_constexpr(cond)
if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks
# not isinstance - we insist the real thing, no subclasses and no ducks
if type(cond) not in _condition_types:
raise UnsupportedLanguageConstruct(
None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
', '.join(_.__name__ for _ in _condition_types), type(cond).__name__))
None, node,
"`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
', '.join(_.__name__ for _ in _condition_types),
type(cond).__name__))
if cond:
return self.visit(node.body)
else:
@@ -654,8 +712,10 @@ class CodeGenerator(ast.NodeVisitor):
return constexpr(lhs_value is not rhs_value)
method_name = self._method_name_for_comp_op.get(type(node.ops[0]))
if method_name is None:
raise UnsupportedLanguageConstruct(None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
raise UnsupportedLanguageConstruct(
None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
return self._apply_binary_method(method_name, lhs, rhs)
_method_name_for_comp_op: Dict[Type[ast.cmpop], str] = {
ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__'
}
@@ -664,11 +724,15 @@ class CodeGenerator(ast.NodeVisitor):
op = self.visit(node.operand)
fn = self._method_name_for_unary_op.get(type(node.op))
if fn is None:
raise UnsupportedLanguageConstruct(None, node, "AST unary operator '{}' is not (currently) implemented.".format(node.op.__name__))
raise UnsupportedLanguageConstruct(
None, node, "AST unary operator '{}' is not (currently) implemented.".format(node.op.__name__))
if _is_triton_tensor(op):
return getattr(op, fn)(_builder=self.builder)
return getattr(op, fn)()
_method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'}
_method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {
ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'
}
def visit_While(self, node):
with enter_sub_region(self) as sr:
@@ -763,9 +827,7 @@ class CodeGenerator(ast.NodeVisitor):
iter_args = [self.visit(arg) for arg in node.iter.args]
if IteratorClass == language.static_range:
iterator = IteratorClass(*iter_args)
static_range = range(iterator.start.value,
iterator.end.value,
iterator.step.value)
static_range = range(iterator.start.value, iterator.end.value, iterator.step.value)
for i in static_range:
self.lscope[node.target.id] = constexpr(i)
self.visit_compound_statement(node.body)
@@ -902,8 +964,7 @@ class CodeGenerator(ast.NodeVisitor):
def call_JitFunction(self, fn: JITFunction, args, kwargs):
args = inspect.getcallargs(fn.fn, *args, **kwargs)
args = [args[name] for name in fn.arg_names]
args = [arg if _is_triton_tensor(arg)
else constexpr(arg) for arg in args]
args = [arg if _is_triton_tensor(arg) else constexpr(arg) for arg in args]
# generate function def
attributes = dict()
constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
@@ -921,8 +982,9 @@ class CodeGenerator(ast.NodeVisitor):
debug = self.debug if fn.debug is None else fn.debug
file_name, begin_line = _get_fn_file_line(fn)
generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module,
function_name=fn_name, function_types=self.function_ret_types, debug=debug, noinline=fn.noinline,
file_name=file_name, begin_line=begin_line, target=self.builder.target)
function_name=fn_name, function_types=self.function_ret_types, debug=debug,
noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
target=self.builder.target)
generator.visit(fn.parse())
callee_ret_type = generator.last_ret_type
self.function_ret_types[fn_name] = callee_ret_type
@@ -950,7 +1012,7 @@ class CodeGenerator(ast.NodeVisitor):
kws = dict(self.visit(keyword) for keyword in node.keywords)
args = [self.visit(arg) for arg in node.args]
if fn is language.core.device_assert: # TODO: this should not be so hardcoded
if fn is language.core.device_assert: # TODO: this should not be so hardcoded
if not self.debug:
return
if isinstance(fn, JITFunction):
@@ -971,16 +1033,21 @@ class CodeGenerator(ast.NodeVisitor):
def visit_BoolOp(self, node: ast.BoolOp):
if len(node.values) != 2:
raise UnsupportedLanguageConstruct(None, node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
raise UnsupportedLanguageConstruct(
None, node,
"chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
lhs = self.visit(node.values[0])
rhs = self.visit(node.values[1])
method_name = self._method_name_for_bool_op.get(type(node.op))
if method_name is None:
raise UnsupportedLanguageConstruct(None, node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
raise UnsupportedLanguageConstruct(
None, node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
return self._apply_binary_method(method_name, lhs, rhs)
_method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}
if sys.version_info < (3, 8):
def visit_NameConstant(self, node):
return constexpr(node.value)
@@ -1013,7 +1080,9 @@ class CodeGenerator(ast.NodeVisitor):
evaluated = self.visit(value.value)
if not _is_constexpr(evaluated):
raise UnsupportedLanguageConstruct(
None, node, "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + str(type(evaluated)))
None, node,
"Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type "
+ str(type(evaluated)))
values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value)
else:
raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value)))
@@ -1055,7 +1124,9 @@ class CodeGenerator(ast.NodeVisitor):
passed = _unwrap_if_constexpr(self.visit(node.args[0]))
if not isinstance(passed, bool):
raise NotImplementedError("Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values")
raise NotImplementedError(
"Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values"
)
if not passed:
if arg_count == 1:
message = ""
@@ -1144,10 +1215,9 @@ def ast_to_ttir(fn, signature, specialization, constants, debug, target):
file_name, begin_line = _get_fn_file_line(fn)
prototype = language.function_type([], arg_types)
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants,
function_name=function_name, attributes=new_attrs,
is_kernel=True, debug=debug, file_name=file_name, begin_line=begin_line,
target=target)
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name,
attributes=new_attrs, is_kernel=True, debug=debug, file_name=file_name,
begin_line=begin_line, target=target)
try:
generator.visit(fn.parse())
except CompilationError as e:

View File

@@ -11,25 +11,21 @@ from typing import Any
from dataclasses import dataclass
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs,
compile_ptx_to_cubin, get_env_vars, get_num_warps,
get_shared_memory_size, ir, runtime,
translate_llvmir_to_ptx,
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs, compile_ptx_to_cubin, get_env_vars,
get_num_warps, get_shared_memory_size, ir, runtime, translate_llvmir_to_ptx,
translate_triton_gpu_to_llvmir)
from ..common.backend import get_backend, path_to_ptxas
from ..common.backend import get_backend, get_cuda_version_key, path_to_ptxas
from ..common.build import is_hip
# from ..runtime import driver, jit, JITFunction
# TODO: runtime.errors
from ..runtime.autotuner import OutOfResources
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
from ..runtime.driver import driver
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device,
get_device_capability, version_key)
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device, get_device_capability)
from ..tools.disasm import get_sass
from .code_generator import ast_to_ttir
from .make_launcher import make_stub
from .utils import (InfoFromBackendForTensorMap, TensorMapManager,
get_ids_of_tensormaps, parse_tma_info)
from .utils import (InfoFromBackendForTensorMap, TensorMapManager, get_ids_of_tensormaps, parse_tma_info)
CUDA_DEFAULT_WARP_SIZE = 32
@@ -45,6 +41,7 @@ def _is_cuda(target):
class LazyDict(dict):
def __getitem__(self, key):
val = dict.__getitem__(self, key)
if callable(val):
@@ -102,8 +99,8 @@ def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, target):
return mod
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target,
cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_inst_type):
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization,
enable_persistent, optimize_epilogue, matrix_inst_type):
is_cuda = _is_cuda(target)
if is_cuda:
capability = target.capability
@@ -173,6 +170,8 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target,
if is_cuda and capability // 10 >= 9:
pm.add_tritongpu_fence_insertion_pass()
pm.add_tritongpu_ws_fixup_missing_attrs_pass()
pm.add_tritongpu_optimize_thread_locality_pass()
pm.add_canonicalizer_pass()
pm.run(mod)
return mod
@@ -196,6 +195,7 @@ def ttgir_to_llir(mod, extern_libs, target, tma_infos, waves_per_eu=0):
# PTX translation
@functools.lru_cache()
def ptx_get_version(cuda_version) -> int:
'''
@@ -260,7 +260,11 @@ def convert_type_repr(x):
return x
def make_hash(fn, target, env_vars, **kwargs):
def make_hash(fn, target, env_vars, device_backend, **kwargs):
if device_backend is None:
version_key = get_cuda_version_key()
else:
version_key = device_backend.get_version_key()
if isinstance(fn, JITFunction):
configs = kwargs["configs"]
signature = kwargs["signature"]
@@ -274,16 +278,17 @@ def make_hash(fn, target, env_vars, **kwargs):
enable_persistent = kwargs.get("enable_persistent", False)
debug = kwargs.get("debug", False)
# Get unique key for the compiled code
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1),
sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
configs_key = [get_conf_key(conf) for conf in configs]
env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())]
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
return hashlib.md5(key.encode("utf-8")).hexdigest()
assert isinstance(fn, str)
ignore_version = kwargs.get('ignore_version', False)
if (ignore_version):
return hashlib.md5((Path(fn).read_text()).encode("utf-8")).hexdigest()
return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest()
return hashlib.md5((Path(fn).read_text() + version_key).encode("utf-8")).hexdigest()
# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
@@ -320,12 +325,14 @@ else:
def _get_jsonable_constants(constants):
def _is_jsonable(x):
try:
json.dumps(x)
return True
except (TypeError, OverflowError):
return False
serialized_constants = {}
for constant in constants:
if _is_jsonable(constants[constant]):
@@ -340,7 +347,9 @@ def parse_mlir_module(path, context):
return module
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], defaults=[set(), set(), set(), set()])
instance_descriptor = namedtuple("instance_descriptor",
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
defaults=[set(), set(), set(), set()])
def is_hip():
@@ -382,10 +391,9 @@ def get_arch_default_num_stages(device_type, capability=None):
def add_cuda_stages(target, extern_libs, stages):
stages["ptx"] = (lambda path: Path(path).read_text(),
lambda src: llir_to_ptx(src, target))
stages["cubin"] = (lambda path: Path(path).read_bytes(),
lambda src: ptx_to_cubin(src, target))
stages["ptx"] = (lambda path: Path(path).read_text(), lambda src: llir_to_ptx(src, target))
stages["cubin"] = (lambda path: Path(path).read_bytes(), lambda src: ptx_to_cubin(src, target))
def compile(fn, **kwargs):
@@ -431,7 +439,8 @@ def compile(fn, **kwargs):
# build architecture descriptor
if device_type == "cuda":
_device_backend = get_backend(device_type)
target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps, enable_fp_fusion=enable_fp_fusion)
target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps,
enable_fp_fusion=enable_fp_fusion)
else:
_device_backend = get_backend(device_type)
assert _device_backend
@@ -440,11 +449,12 @@ def compile(fn, **kwargs):
# build compilation stages
stages = dict()
stages["ast"] = (lambda path: fn, None)
stages["ttir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
stages["ttir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttir(
ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
if is_cuda:
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
stages["ttgir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttgir(
ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info,
enable_warp_specialization, enable_persistent, optimize_epilogue))
stages["llir"] = (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, target, tma_infos))
add_cuda_stages(target, extern_libs, stages)
@@ -504,18 +514,21 @@ def compile(fn, **kwargs):
if ir_name == 'ttgir':
num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
assert "num_warps" not in kwargs or int(num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile"
assert "num_warps" not in kwargs or int(
num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile"
num_warps = int(num_warps_matches[0])
param_tys = [convert_type_repr(ty) for ty in types]
signature = {k: v for k, v in enumerate(param_tys)}
first_stage = list(stages.keys()).index(ir_name)
# create cache manager
fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), **kwargs))
fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs))
# managers used to dump and override IR for debugging
enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
fn_override_manager = get_override_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True))
fn_dump_manager = get_dump_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True))
fn_override_manager = get_override_manager(
make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
fn_dump_manager = get_dump_manager(
make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
# determine name and extension type of provided function
if isinstance(fn, JITFunction):
@@ -528,9 +541,7 @@ def compile(fn, **kwargs):
metadata_filename = f"{name}.json"
# The group is addressed by the metadata
metadata_group = fn_cache_manager.get_group(
metadata_filename
) or {}
metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
metadata_path = metadata_group.get(metadata_filename)
@@ -538,20 +549,21 @@ def compile(fn, **kwargs):
with open(metadata_path) as f:
metadata = json.load(f)
if 'tensormaps_info' in metadata:
metadata['tensormaps_info'] = [
InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']]
metadata['tensormaps_info'] = [InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']]
else:
metadata = {"num_warps": num_warps,
"warp_size": warp_size,
"num_ctas": num_ctas,
"num_stages": num_stages,
"waves_per_eu": waves_per_eu,
"matrix_instr_nonkdim": matrix_instr_nonkdim,
"enable_warp_specialization": enable_warp_specialization,
"enable_persistent": enable_persistent,
"constants": _get_jsonable_constants(constants),
"debug": debug,
"target": target, }
metadata = {
"num_warps": num_warps,
"warp_size": warp_size,
"num_ctas": num_ctas,
"num_stages": num_stages,
"waves_per_eu": waves_per_eu,
"matrix_instr_nonkdim": matrix_instr_nonkdim,
"enable_warp_specialization": enable_warp_specialization,
"enable_persistent": enable_persistent,
"constants": _get_jsonable_constants(constants),
"debug": debug,
"target": target,
}
metadata.update(get_env_vars())
if ext == "ptx":
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
@@ -623,10 +635,7 @@ def compile(fn, **kwargs):
ids_of_folded_args = tuple([int(k) for k in configs[0].ids_of_folded_args]) if isinstance(fn, JITFunction) else ()
if "clusterDims" not in metadata:
metadata["clusterDims"] = [
cluster_info.clusterDimX,
cluster_info.clusterDimY,
cluster_info.clusterDimZ]
metadata["clusterDims"] = [cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ]
if len(tma_infos) > 0:
metadata["tensormaps_info"] = parse_tma_info(tma_infos, ids_of_folded_args)
@@ -640,7 +649,10 @@ def compile(fn, **kwargs):
fn.tensormaps_info = metadata["tensormaps_info"]
ids_of_const_exprs = tuple(fn.constexprs) if isinstance(fn, JITFunction) else ()
ids = {"ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs": ids_of_const_exprs}
ids = {
"ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs":
ids_of_const_exprs
}
# cache manager
if is_cuda:
so_path = make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization)
@@ -648,7 +660,8 @@ def compile(fn, **kwargs):
so_path = _device_backend.make_launcher_stub(name, signature, constants, ids)
# write-back metadata, if it didn't come from the cache
if metadata_path is None:
metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, binary=False)
metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
binary=False)
fn_cache_manager.put_group(metadata_filename, metadata_group)
# return handle to compiled kernel
@@ -698,10 +711,7 @@ class CompiledKernel:
if self.device_type in ["cuda"]:
device = get_current_device()
bin_path = {
driver.HIP: "hsaco_path",
driver.CUDA: "cubin"
}[driver.backend]
bin_path = {driver.HIP: "hsaco_path", driver.CUDA: "cubin"}[driver.backend]
max_shared = driver.utils.get_device_properties(device)["max_shared_mem"]
fn_load_binary = driver.utils.load_binary
else:
@@ -749,4 +759,5 @@ class CompiledKernel:
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.clusterDims[0],
self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function,
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand)
return runner

View File

@@ -3,9 +3,9 @@ import os
import tempfile
from ..common import _build
from ..common.backend import get_cuda_version_key
from ..common.build import is_hip
from ..runtime.cache import get_cache_manager
from ..runtime.jit import version_key
from .utils import generate_cu_signature
# ----- stub --------
@@ -23,7 +23,7 @@ def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
def make_stub(name, signature, constants, ids, **kwargs):
# name of files that are cached
so_cache_key = make_so_cache_key(version_key(), signature, constants, ids, **kwargs)
so_cache_key = make_so_cache_key(get_cuda_version_key(), signature, constants, ids, **kwargs)
so_cache_manager = get_cache_manager(so_cache_key)
so_name = f"{name}.so"
# retrieve stub from cache if it exists
@@ -40,6 +40,7 @@ def make_stub(name, signature, constants, ids, **kwargs):
else:
return cache_path
# ----- source code generation --------
@@ -100,7 +101,10 @@ def generate_launcher(constants, signature, ids):
# generate glue code
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
params = [i for i in signature.keys() if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)]
params = [
i for i in signature.keys()
if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)
]
src = f"""
#include \"cuda.h\"
#include <stdbool.h>

View File

@@ -158,19 +158,21 @@ class InfoFromBackendForTensorMap:
# dtype:cuda.CUtensorMapDataType | int
def bytes_from_type(self, dtype):
return {driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT8"]: 1,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT16"]: 2,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT64"]: 8,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT64"]: 8,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]: 2,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT64"]: 8,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"]: 2,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ"]: 4}[dtype]
return {
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT8"]: 1,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT16"]: 2,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT64"]: 8,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT64"]: 8,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]: 2,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT64"]: 8,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"]: 2,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ"]: 4
}[dtype]
def getTensorMapDataType(self):
return self.tensorDataType
@@ -259,22 +261,29 @@ class InfoFromBackendForTensorMap:
self.getInterleave(),
self.getSwizzle(),
self.getL2Promotion(),
self.getOobFill()
self.getOobFill(),
)
# make hashable to use as partial key in cache
def __hash__(self):
return hash((self.ids_of_folded_args, self.globalAddressArgIdx, tuple(self.globalDimsArgIdx), tuple(self.globalStridesArgIdx), self.tensorDataType,
self.tensorRank, tuple(self.boxDims), tuple(self.elementStrides), self.interleave, self.swizzle, self.l2Promotion, self.oobFill))
return hash((self.ids_of_folded_args, self.globalAddressArgIdx, tuple(self.globalDimsArgIdx),
tuple(self.globalStridesArgIdx), self.tensorDataType, self.tensorRank, tuple(self.boxDims),
tuple(self.elementStrides), self.interleave, self.swizzle, self.l2Promotion, self.oobFill))
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
return (self.ids_of_folded_args, self.globalAddressArgIdx, self.globalDimsArgIdx, self.globalStridesArgIdx, self.tensorDataType, self.tensorRank, self.boxDims, self.elementStrides, self.interleave, self.swizzle, self.l2Promotion, self.oobFill) == (
other.ids_of_folded_args, other.globalAddressArgIdx, other.globalDimsArgIdx, other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims, other.elementStrides, other.interleave, other.swizzle, other.l2Promotion, other.oobFill)
return (self.ids_of_folded_args, self.globalAddressArgIdx, self.globalDimsArgIdx, self.globalStridesArgIdx,
self.tensorDataType, self.tensorRank, self.boxDims, self.elementStrides, self.interleave, self.swizzle,
self.l2Promotion,
self.oobFill) == (other.ids_of_folded_args, other.globalAddressArgIdx, other.globalDimsArgIdx,
other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims,
other.elementStrides, other.interleave, other.swizzle, other.l2Promotion,
other.oobFill)
class TensorMapManager:
def __init__(self):
self.tensormaps_device = {}
@@ -286,8 +295,7 @@ class TensorMapManager:
t_tensormap = e.tensormap(args)
TENSORMAP_SIZE_IN_BYTES = 128
t_tensormap_device = driver.utils.cuMemAlloc(TENSORMAP_SIZE_IN_BYTES)
driver.utils.cuMemcpyHtoD(
t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES)
driver.utils.cuMemcpyHtoD(t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES)
self.tensormaps_device[key] = t_tensormap_device
return int(self.tensormaps_device[key])

View File

@@ -111,7 +111,6 @@ from .random import (
uint32_to_uniform_float,
)
__all__ = [
"TRITON_MAX_TENSOR_NUMEL",
"abs",

View File

@@ -22,10 +22,8 @@ def builtin(fn: T) -> T:
@wraps(fn)
def wrapper(*args, **kwargs):
if "_builder" not in kwargs or kwargs["_builder"] is None:
raise ValueError(
"Did you forget to add @triton.jit ? "
"(`_builder` argument must be provided outside of JIT functions.)"
)
raise ValueError("Did you forget to add @triton.jit ? "
"(`_builder` argument must be provided outside of JIT functions.)")
return fn(*args, **kwargs)
setattr(wrapper, TRITON_BUILTIN, True)
@@ -54,7 +52,7 @@ def _to_tensor(x, builder):
else:
raise RuntimeError(f'Nonrepresentable integer {x}.')
elif isinstance(x, float):
min_float32 = 2 ** -126
min_float32 = 2**-126
max_float32 = (2 - 2**-23) * 2**127
abs_x = __builtins__['abs'](x)
if abs_x == float("inf") or\
@@ -243,7 +241,7 @@ class dtype:
return not self.__eq__(other)
def __hash__(self):
return hash((self.name,))
return hash((self.name, ))
@property
def scalar(self):
@@ -297,6 +295,7 @@ class dtype:
class pointer_type(dtype):
def __init__(self, element_ty: dtype, address_space: int = 1):
if not isinstance(element_ty, dtype):
raise TypeError('element_ty is a {type(element_ty).__name__}.')
@@ -331,6 +330,7 @@ class pointer_type(dtype):
class block_type(dtype):
def __init__(self, element_ty: dtype, shape: List):
self.element_ty = element_ty
@@ -381,6 +381,7 @@ class block_type(dtype):
class function_type(dtype):
def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None:
self.ret_types = ret_types
self.param_types = param_types
@@ -531,7 +532,7 @@ class constexpr:
return constexpr(~self.value)
def __pow__(self, other):
return constexpr(self.value ** other.value)
return constexpr(self.value**other.value)
def __rshift__(self, other):
return constexpr(self.value >> other.value)
@@ -547,6 +548,7 @@ class constexpr:
class tensor:
def __init__(self, handle, type: dtype):
# IR handle
self.handle = handle
@@ -740,11 +742,21 @@ class tensor:
other = _to_tensor(other, _builder)
return semantic.equal(self, other, _builder)
@builtin
def __req__(self, other, _builder=None):
other = _to_tensor(other, _builder)
return semantic.equal(other, self, _builder)
@builtin
def __ne__(self, other, _builder=None):
other = _to_tensor(other, _builder)
return semantic.not_equal(self, other, _builder)
@builtin
def __rne__(self, other, _builder=None):
other = _to_tensor(other, _builder)
return semantic.not_equal(other, self, _builder)
@builtin
def logical_and(self, other, _builder=None):
other = _to_tensor(other, _builder)
@@ -1023,6 +1035,7 @@ def expand_dims(input, axis, _builder=None):
ret = semantic.expand_dims(ret, a, _builder)
return ret
# -----------------------
# Linear Algebra
# -----------------------
@@ -1171,6 +1184,7 @@ def advance(base: tensor, offsets, _builder=None):
"""
return semantic.advance(base, offsets, _builder)
# -----------------------
# Atomic Memory Operations
# -----------------------
@@ -1196,6 +1210,9 @@ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
:param sem: Memory semantics to use ("ACQUIRE_RELEASE" (default),
"ACQUIRE", "RELEASE", or "RELAXED")
:type sem: str
:param scope: Scope of threads that observe synchronizing effect of the
atomic operation ("GPU" (default), "CTA", or "SYSTEM")
:type scope: str
"""
func.__doc__ = docstr
return func
@@ -1205,73 +1222,82 @@ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
@builtin
@_add_atomic_docstr("compare-and-swap", has_cmp=True)
def atomic_cas(pointer, cmp, val, sem=None, _builder=None):
def atomic_cas(pointer, cmp, val, sem=None, scope=None, _builder=None):
cmp = _to_tensor(cmp, _builder)
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_cas(pointer, cmp, val, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_cas(pointer, cmp, val, sem, scope, _builder)
@builtin
@_add_atomic_docstr("exchange")
def atomic_xchg(pointer, val, mask=None, sem=None, _builder=None):
def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _builder=None):
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_xchg(pointer, val, mask, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_xchg(pointer, val, mask, sem, scope, _builder)
@builtin
@_add_atomic_docstr("add")
def atomic_add(pointer, val, mask=None, sem=None, _builder=None):
def atomic_add(pointer, val, mask=None, sem=None, scope=None, _builder=None):
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_add(pointer, val, mask, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_add(pointer, val, mask, sem, scope, _builder)
@builtin
@_add_atomic_docstr("max")
def atomic_max(pointer, val, mask=None, sem=None, _builder=None):
def atomic_max(pointer, val, mask=None, sem=None, scope=None, _builder=None):
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_max(pointer, val, mask, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_max(pointer, val, mask, sem, scope, _builder)
@builtin
@_add_atomic_docstr("min")
def atomic_min(pointer, val, mask=None, sem=None, _builder=None):
def atomic_min(pointer, val, mask=None, sem=None, scope=None, _builder=None):
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_min(pointer, val, mask, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_min(pointer, val, mask, sem, scope, _builder)
@builtin
@_add_atomic_docstr("logical and")
def atomic_and(pointer, val, mask=None, sem=None, _builder=None):
def atomic_and(pointer, val, mask=None, sem=None, scope=None, _builder=None):
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_and(pointer, val, mask, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_and(pointer, val, mask, sem, scope, _builder)
@builtin
@_add_atomic_docstr("logical or")
def atomic_or(pointer, val, mask=None, sem=None, _builder=None):
def atomic_or(pointer, val, mask=None, sem=None, scope=None, _builder=None):
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_or(pointer, val, mask, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_or(pointer, val, mask, sem, scope, _builder)
@builtin
@_add_atomic_docstr("logical xor")
def atomic_xor(pointer, val, mask=None, sem=None, _builder=None):
def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None):
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_xor(pointer, val, mask, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_xor(pointer, val, mask, sem, scope, _builder)
# -----------------------
# Conditioning
# -----------------------
@builtin
def where(condition, x, y, _builder=None):
"""
@@ -1299,6 +1325,7 @@ def where(condition, x, y, _builder=None):
# Math
# -----------------------
@builtin
def umulhi(x, y, _builder=None):
"""
@@ -1392,6 +1419,7 @@ def abs(x, _builder=None):
# Reductions
# -----------------------
def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]:
def _decorator(func: T) -> T:
@@ -1430,8 +1458,7 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None):
"""
if isinstance(input, tensor):
return reduce((input,), axis, combine_fn,
_builder=_builder, _generator=_generator)[0]
return reduce((input, ), axis, combine_fn, _builder=_builder, _generator=_generator)[0]
def make_combine_region(reduce_op):
in_scalar_tys = [t.type.scalar for t in input]
@@ -1441,14 +1468,14 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None):
with _insertion_guard(_builder):
param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
block = _builder.create_block_with_parent(region, param_types)
args = [tensor(block.arg(i), ty)
for i, ty in enumerate(prototype.param_types)]
args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
results = _generator.call_JitFunction(combine_fn, args, kwargs={})
if isinstance(results, tensor):
handles = [results.handle]
else:
handles = [r.handle for r in results]
_builder.create_reduce_ret(*handles)
if axis is not None:
axis = _constexpr_to_value(axis)
return semantic.reduction(input, axis, make_combine_region, _builder)
@@ -1483,8 +1510,7 @@ def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None
index = expand_dims(index, axes_to_expand, _builder=_builder)
index = broadcast_to(index, input.shape, _builder=_builder)
rvalue, rindices = reduce((input, index), axis, combine_fn,
_builder=_builder, _generator=_generator)
rvalue, rindices = reduce((input, index), axis, combine_fn, _builder=_builder, _generator=_generator)
return rvalue, rindices
@@ -1492,6 +1518,7 @@ def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None
# Scans
# -----------------------
def _add_scan_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]:
def _decorator(func: T) -> T:
@@ -1516,8 +1543,7 @@ def associative_scan(input, axis, combine_fn, _builder=None, _generator=None):
"""
if isinstance(input, tensor):
return associative_scan((input,), axis, combine_fn,
_builder=_builder, _generator=_generator)[0]
return associative_scan((input, ), axis, combine_fn, _builder=_builder, _generator=_generator)[0]
def make_combine_region(scan_op):
in_scalar_tys = [t.type.scalar for t in input]
@@ -1527,17 +1553,18 @@ def associative_scan(input, axis, combine_fn, _builder=None, _generator=None):
with _insertion_guard(_builder):
param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
block = _builder.create_block_with_parent(region, param_types)
args = [tensor(block.arg(i), ty)
for i, ty in enumerate(prototype.param_types)]
args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
results = _generator.call_JitFunction(combine_fn, args, kwargs={})
if isinstance(results, tensor):
handles = [results.handle]
else:
handles = [r.handle for r in results]
_builder.create_scan_ret(*handles)
axis = _constexpr_to_value(axis)
return semantic.associative_scan(input, axis, make_combine_region, _builder)
# -----------------------
# Compiler Hint Ops
# -----------------------
@@ -1600,6 +1627,8 @@ def max_constancy(input, values, _builder=None):
raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
values = [x.value for x in values]
return semantic.max_constancy(input, values)
# -----------------------
# Debugging functions
# -----------------------
@@ -1739,12 +1768,12 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur
broadcast_arg = dispatch_args[0]
# Get the broadcast shape over all the arguments
for i, item in enumerate(dispatch_args):
_, broadcast_arg = semantic.binary_op_type_checking_impl(
item, broadcast_arg, _builder, arithmetic_check=False)
_, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder,
arithmetic_check=False)
# Change the shape of each argument based on the broadcast shape
for i in range(len(dispatch_args)):
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
dispatch_args[i], broadcast_arg, _builder, arithmetic_check=False)
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder,
arithmetic_check=False)
ret_shape = broadcast_arg.shape
res_ty = block_type(dtype, ret_shape)
call = _builder.create_inline_asm(asm, constraints, [t.handle for t in args], res_ty.to_ir(_builder), is_pure, pack)
@@ -1757,7 +1786,6 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur
class static_range:
"""
Iterator that counts upward forever.
@@ -1801,7 +1829,9 @@ class static_range:
# Extern functions
# -----------------------
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, is_pure: bool, _builder=None):
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple,
is_pure: bool, _builder=None):
'''
Dispatch a function to a library
:param func: the function to dispatch
@@ -1843,7 +1873,8 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic
return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type)
def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, _builder=None):
def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool,
_builder=None):
'''
Dispatch an elementwise function to a library
:param lib_name: the name of the library
@@ -1872,12 +1903,12 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
broadcast_arg = dispatch_args[0]
# Get the broadcast shape over all the arguments
for i, item in enumerate(dispatch_args):
_, broadcast_arg = semantic.binary_op_type_checking_impl(
item, broadcast_arg, _builder, arithmetic_check=arithmetic_check)
_, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder,
arithmetic_check=arithmetic_check)
# Change the shape of each argument based on the broadcast shape
for i in range(len(dispatch_args)):
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
dispatch_args[i], broadcast_arg, _builder, arithmetic_check=arithmetic_check)
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder,
arithmetic_check=arithmetic_check)
if not all_scalar:
ret_shape = broadcast_arg.shape
func = getattr(_builder, "create_extern_elementwise")

View File

@@ -3,16 +3,14 @@ from .. import core
@core.extern
def globaltimer(_builder=None):
return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [],
dtype=core.int64, is_pure=False,
pack=1, _builder=_builder)
return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], dtype=core.int64, is_pure=False, pack=1,
_builder=_builder)
@core.extern
def smid(_builder=None):
return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [],
dtype=core.int32, is_pure=True,
pack=1, _builder=_builder)
return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1,
_builder=_builder)
@core.builtin

File diff suppressed because it is too large Load Diff

View File

@@ -91,6 +91,7 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
# two_to_the_minus_32: tl.constexpr = 2.328306e-10
# return x * two_to_the_minus_32
@jit
def uint32_to_uniform_float(x):
"""
@@ -134,6 +135,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
u4 = uint32_to_uniform_float(i4)
return u1, u2, u3, u4
# -------------------
# randn
# -------------------

View File

@@ -19,10 +19,12 @@ def _is_cuda(target):
from ..compiler.compiler import CudaTargetDescriptor
return isinstance(target, CudaTargetDescriptor)
# Create custom exception that prints message "hello"
class IncompatibleTypeErrorImpl(Exception):
def __init__(self, type_a, type_b):
self.type_a = type_a
self.type_b = type_b
@@ -34,6 +36,7 @@ class IncompatibleTypeErrorImpl(Exception):
# Programming Model
# ===----------------------------------------------------------------------===##
def program_id(axis: int, builder: ir.builder) -> tl.tensor:
if axis not in (0, 1, 2):
raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}")
@@ -45,6 +48,7 @@ def num_programs(axis: int, builder: ir.builder) -> tl.tensor:
raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}")
return tl.tensor(builder.create_get_num_programs(axis), tl.int32)
# ===----------------------------------------------------------------------===//
# Implicit Casting Utilities
# ===----------------------------------------------------------------------===//
@@ -95,10 +99,12 @@ def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> t
# 5 ) both operands are integer and undergo
# integer promotion
if div_or_mod and a_ty.int_signedness != b_ty.int_signedness:
raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + " because they have different signedness;"
raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() +
" because they have different signedness;"
"this is unlikely to result in a useful answer. Cast them to the same signedness.")
return integer_promote_impl(a_ty, b_ty)
# ===----------------------------------------------------------------------===//
# Binary Operators
# ===----------------------------------------------------------------------===//
@@ -116,12 +122,9 @@ def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -
raise IncompatibleTypeErrorImpl(type_a, type_b)
def binary_op_type_checking_impl(lhs: tl.tensor,
rhs: tl.tensor,
builder: ir.builder,
allow_lhs_ptr=False, allow_rhs_ptr=False,
arithmetic_check=True, div_or_mod=False
) -> Tuple[tl.tensor, tl.tensor]:
def binary_op_type_checking_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder, allow_lhs_ptr=False,
allow_rhs_ptr=False, arithmetic_check=True,
div_or_mod=False) -> Tuple[tl.tensor, tl.tensor]:
# implicit broadcasting
lhs, rhs = broadcast_impl_value(lhs, rhs, builder)
# implicit typecasting
@@ -136,9 +139,7 @@ def binary_op_type_checking_impl(lhs: tl.tensor,
return lhs, rhs
def add(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def add(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder, True, True)
input_scalar_ty = input.type.scalar
other_scalar_ty = other.type.scalar
@@ -162,15 +163,12 @@ def add(input: tl.tensor,
assert False
def sub(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def sub(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder, True, False)
scalar_ty = input.type.scalar
# ptr - offset
if scalar_ty.is_ptr():
return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle),
input.type)
return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle), input.type)
# float - float
if scalar_ty.is_floating():
return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type)
@@ -180,9 +178,7 @@ def sub(input: tl.tensor,
assert False
def mul(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def mul(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder)
scalar_ty = input.type.scalar
# float * float
@@ -194,9 +190,7 @@ def mul(input: tl.tensor,
assert False
def truediv(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def truediv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
input_scalar_ty = input.type.scalar
other_scalar_ty = other.type.scalar
@@ -222,9 +216,7 @@ def truediv(input: tl.tensor,
return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type)
def floordiv(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def floordiv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
input_scalar_ty = input.type.scalar
other_scalar_ty = other.type.scalar
@@ -239,10 +231,7 @@ def floordiv(input: tl.tensor,
assert False
def fdiv(input: tl.tensor,
other: tl.tensor,
ieee_rounding: bool,
builder: ir.builder) -> tl.tensor:
def fdiv(input: tl.tensor, other: tl.tensor, ieee_rounding: bool, builder: ir.builder) -> tl.tensor:
input_scalar_ty = input.type.scalar
other_scalar_ty = other.type.scalar
if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating():
@@ -252,18 +241,14 @@ def fdiv(input: tl.tensor,
return tl.tensor(ret, input.type)
def mod(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def mod(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
scalar_ty = input.type.scalar
other_scalar_ty = other.type.scalar
# float % float
if scalar_ty.is_floating():
# input - input.div(other, rounding_mode="floor") * other
ret = sub(input, mul(floor(fdiv(input, other, False, builder), builder),
other, builder),
builder)
ret = sub(input, mul(floor(fdiv(input, other, False, builder), builder), other, builder), builder)
return ret
# % int
elif scalar_ty.is_int():
@@ -277,13 +262,13 @@ def mod(input: tl.tensor,
return tl.tensor(builder.create_urem(input.handle, other.handle), input.type)
assert False
##############
# bitwise ops
##############
def bitwise_op_type_checking_impl(input: tl.tensor,
other: tl.tensor,
def bitwise_op_type_checking_impl(input: tl.tensor, other: tl.tensor,
builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]:
input, other = binary_op_type_checking_impl(input, other, builder, False, False, False)
input_sca_ty = input.type.scalar
@@ -298,23 +283,17 @@ def bitwise_op_type_checking_impl(input: tl.tensor,
return input, other
def and_(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def and_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = bitwise_op_type_checking_impl(input, other, builder)
return tl.tensor(builder.create_and(input.handle, other.handle), input.type)
def or_(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def or_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = bitwise_op_type_checking_impl(input, other, builder)
return tl.tensor(builder.create_or(input.handle, other.handle), input.type)
def xor_(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def xor_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = bitwise_op_type_checking_impl(input, other, builder)
return tl.tensor(builder.create_xor(input.handle, other.handle), input.type)
@@ -341,26 +320,21 @@ def not_(input: tl.tensor, builder: ir.builder):
return invert(input, builder)
def lshr(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def lshr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = bitwise_op_type_checking_impl(input, other, builder)
return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type)
def ashr(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def ashr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = bitwise_op_type_checking_impl(input, other, builder)
return tl.tensor(builder.create_ashr(input.handle, other.handle), input.type)
def shl(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def shl(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = bitwise_op_type_checking_impl(input, other, builder)
return tl.tensor(builder.create_shl(input.handle, other.handle), input.type)
# ===----------------------------------------------------------------------===//
# Unary Operators
# ===----------------------------------------------------------------------===//
@@ -370,8 +344,7 @@ def plus(input: tl.tensor) -> tl.tensor:
return input
def minus(input: tl.tensor,
builder: ir.builder) -> tl.tensor:
def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor:
input_sca_ty = input.type.scalar
if input_sca_ty.is_ptr():
raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")")
@@ -379,8 +352,7 @@ def minus(input: tl.tensor,
return sub(_0, input, builder)
def invert(input: tl.tensor,
builder: tl.tensor) -> tl.tensor:
def invert(input: tl.tensor, builder: tl.tensor) -> tl.tensor:
input_sca_ty = input.type.scalar
if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")")
@@ -398,9 +370,7 @@ def _bool_like(v: tl.tensor) -> tl.block_type:
return tl.block_type(tl.int1, shape)
def greater_than(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def greater_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder)
scalar_ty = input.type.scalar
# float > float
@@ -415,9 +385,7 @@ def greater_than(input: tl.tensor,
assert False
def greater_equal(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def greater_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder)
scalar_ty = input.type.scalar
# float >= float
@@ -432,9 +400,7 @@ def greater_equal(input: tl.tensor,
assert False
def less_than(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def less_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder)
scalar_ty = input.type.scalar
# float < float
@@ -449,9 +415,7 @@ def less_than(input: tl.tensor,
assert False
def less_equal(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def less_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder)
scalar_ty = input.type.scalar
# float < float
@@ -466,9 +430,7 @@ def less_equal(input: tl.tensor,
assert False
def equal(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder)
scalar_ty = input.type.scalar
# float == float
@@ -480,9 +442,7 @@ def equal(input: tl.tensor,
assert False
def not_equal(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def not_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder)
scalar_ty = input.type.scalar
# float == float
@@ -493,6 +453,7 @@ def not_equal(input: tl.tensor,
return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input))
assert False
# ===----------------------------------------------------------------------===//
# Block Creation
# ===----------------------------------------------------------------------===//
@@ -541,6 +502,7 @@ def ones(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
# Shape Manipulation
# ===----------------------------------------------------------------------===//
def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor:
assert not value.type.is_block(), "Cannot splat a block tensor"
if len(shape) == 0:
@@ -549,9 +511,7 @@ def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_splat(value.handle, shape), ret_ty)
def view(input: tl.tensor,
dst_shape: List[int],
builder: ir.builder) -> tl.tensor:
def view(input: tl.tensor, dst_shape: List[int], builder: ir.builder) -> tl.tensor:
# TODO: disable when TritonToTritonGPU handles views properly
# assert len(input.shape) == len(dst_shape)
@@ -564,9 +524,7 @@ def view(input: tl.tensor,
return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty)
def reshape(input: tl.tensor,
dst_shape: List[int],
builder: ir.builder) -> tl.tensor:
def reshape(input: tl.tensor, dst_shape: List[int], builder: ir.builder) -> tl.tensor:
raise ValueError("`reshape` is not supported yet. Please use `view` instead if applicable. "
"Note that view may reorder elements in an implementation- and context- dependent way.")
@@ -596,9 +554,7 @@ def trans(input: tl.tensor, builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_trans(input.handle), ret_type)
def broadcast_impl_shape(input: tl.tensor,
shape: List[int],
builder: ir.builder) -> tl.tensor:
def broadcast_impl_shape(input: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor:
if not input.type.is_block():
ret_ty = tl.block_type(input.type, shape)
return tl.tensor(builder.create_splat(input.handle, shape), ret_ty)
@@ -616,9 +572,7 @@ def broadcast_impl_shape(input: tl.tensor,
return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty)
def broadcast_impl_value(lhs: tl.tensor,
rhs: tl.tensor,
builder: ir.builder) -> tl.tensor:
def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:
lhs_ty = lhs.type
rhs_ty = rhs.type
@@ -638,13 +592,15 @@ def broadcast_impl_value(lhs: tl.tensor,
if len(lhs_shape) < len(rhs_shape):
# Add new axes to lhs
for dim in range(len(lhs_shape), len(rhs_shape)):
lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0), tl.block_type(lhs_ty.scalar, [1] + lhs_shape))
lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0),
tl.block_type(lhs_ty.scalar, [1] + lhs_shape))
lhs_ty = lhs.type
lhs_shape = lhs_ty.get_block_shapes()
elif len(rhs_shape) < len(lhs_shape):
# Add new axes to rhs
for dim in range(len(rhs_shape), len(lhs_shape)):
rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0), tl.block_type(rhs_ty.scalar, [1] + rhs_shape))
rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0),
tl.block_type(rhs_ty.scalar, [1] + rhs_shape))
rhs_ty = rhs.type
rhs_shape = rhs_ty.get_block_shapes()
assert len(rhs_shape) == len(lhs_shape)
@@ -670,14 +626,13 @@ def broadcast_impl_value(lhs: tl.tensor,
# (scalar, scalar) => returns original blocks
return lhs, rhs
#######
# cast
#######
def bitcast(input: tl.tensor,
dst_ty: tl.dtype,
builder: ir.builder) -> tl.tensor:
def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor:
src_ty = input.type
if src_ty.is_block():
dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes())
@@ -693,13 +648,10 @@ def bitcast(input: tl.tensor,
if src_bits != dst_bits:
raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to "
"data-type of size " + str(dst_bits))
return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)),
dst_ty)
return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty)
def cast(input: tl.tensor,
dst_ty: tl.dtype,
builder: ir.builder) -> tl.tensor:
def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor:
src_ty = input.type
if isinstance(dst_ty, tl.constexpr):
dst_ty = dst_ty.value
@@ -718,8 +670,7 @@ def cast(input: tl.tensor,
# Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64
if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \
(src_sca_ty.is_floating() and dst_sca_ty.is_fp8()):
return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)),
dst_ty)
return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty)
# bf16 <=> (not fp32)
if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \
@@ -733,9 +684,7 @@ def cast(input: tl.tensor,
dst_sca_ty.is_floating() and \
src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth
if truncate_fp:
return tl.tensor(builder.create_fp_trunc(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty)
# Standard floating types' casting: extension
# fp32 => fp64
@@ -745,9 +694,7 @@ def cast(input: tl.tensor,
dst_sca_ty.is_floating() and \
src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth
if ext_fp:
return tl.tensor(builder.create_fp_ext(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty)
# Casting between integer types
if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
@@ -758,9 +705,7 @@ def cast(input: tl.tensor,
_0 = tl.tensor(builder.get_null_value(ty), input.dtype)
return not_equal(input, _0, builder)
else:
return tl.tensor(builder.create_int_cast(input.handle,
dst_ty.to_ir(builder), sign_extend),
dst_ty)
return tl.tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty)
# Casting standard floating types to integer types
if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int():
@@ -769,35 +714,24 @@ def cast(input: tl.tensor,
_0 = tl.tensor(builder.get_null_value(ty), input.dtype)
return not_equal(input, _0, builder)
elif dst_sca_ty.is_int_signed():
return tl.tensor(builder.create_fp_to_si(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
return tl.tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty)
else:
return tl.tensor(builder.create_fp_to_ui(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
return tl.tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty)
# Casting integer types to standard floating types
if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating():
if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed():
return tl.tensor(builder.create_ui_to_fp(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
return tl.tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty)
else:
return tl.tensor(builder.create_si_to_fp(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
return tl.tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty)
# Casting pointer types to integer types
if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
bitwidth = dst_sca_ty.int_bitwidth
if bitwidth == 64:
return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)),
dst_ty)
return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty)
if bitwidth == 1:
return not_equal(cast(input, tl.int64, builder),
tl.tensor(builder.get_int64(0), tl.int64),
builder)
return not_equal(cast(input, tl.int64, builder), tl.tensor(builder.get_int64(0), tl.int64), builder)
# Casting integer types to pointer types
if src_sca_ty.is_int() and dst_sca_ty.is_ptr():
@@ -809,6 +743,7 @@ def cast(input: tl.tensor,
assert False, f'cannot cast {input} to {dst_ty}'
# ===----------------------------------------------------------------------===//
# Memory Operators
# ===----------------------------------------------------------------------===//
@@ -882,6 +817,20 @@ def _str_to_sem(sem_option):
return sem
def _str_to_scope(scope_option):
scope = ir.MEM_SYNC_SCOPE.GPU
if scope_option:
if scope_option == "gpu":
scope = ir.MEM_SYNC_SCOPE.GPU
elif scope_option == "cta":
scope = ir.MEM_SYNC_SCOPE.CTA
elif scope_option == "sys":
scope = ir.MEM_SYNC_SCOPE.SYSTEM
else:
raise ValueError(f"Memory semantic {scope_option} not supported")
return scope
def _canonicalize_boundary_check(boundary_check, block_shape):
if boundary_check:
if not hasattr(boundary_check, "__iter__"):
@@ -913,8 +862,8 @@ def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, evicti
boundary_check = _canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes())
# Build IR
return tl.tensor(builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction,
is_volatile), dst_ty)
return tl.tensor(
builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty)
def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder):
@@ -970,19 +919,13 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_
if not mask:
return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty)
else:
return tl.tensor(builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache,
eviction, is_volatile), dst_ty)
return tl.tensor(
builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction,
is_volatile), dst_ty)
def load(ptr: tl.tensor,
mask: Optional[tl.tensor],
other: Optional[tl.tensor],
boundary_check,
padding_option: str,
cache_modifier: str,
eviction_policy: str,
is_volatile: bool,
builder: ir.builder) -> tl.tensor:
def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check, padding_option: str,
cache_modifier: str, eviction_policy: str, is_volatile: bool, builder: ir.builder) -> tl.tensor:
# Cache, eviction and padding options
cache = _str_to_load_cache_modifier(cache_modifier)
eviction = _str_to_eviction_policy(eviction_policy)
@@ -1007,7 +950,8 @@ def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builde
if not val.type.is_block():
val = broadcast_impl_shape(val, block_shape, builder)
assert val.type.is_block(), "Value argument must be block type or a scalar"
assert block_shape == val.type.get_block_shapes(), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch"
assert block_shape == val.type.get_block_shapes(
), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch"
assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch"
elt_ty = ptr.type.element_ty.element_ty
@@ -1065,13 +1009,8 @@ def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder):
return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void)
def store(ptr: tl.tensor,
val: tl.tensor,
mask: Optional[tl.tensor],
boundary_check,
cache_modifier: str,
eviction_policy: str,
builder: ir.builder) -> tl.tensor:
def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], boundary_check, cache_modifier: str,
eviction_policy: str, builder: ir.builder) -> tl.tensor:
# Cache and eviction options
cache = _str_to_store_cache_modifier(cache_modifier)
eviction = _str_to_eviction_policy(eviction_policy)
@@ -1089,22 +1028,16 @@ def store(ptr: tl.tensor,
#########
def atomic_cas(ptr: tl.tensor,
cmp: tl.tensor,
val: tl.tensor,
sem: str,
builder: ir.builder) -> tl.tensor:
def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
sem = _str_to_sem(sem)
scope = _str_to_scope(scope)
element_ty = ptr.type.scalar.element_ty
if element_ty.primitive_bitwidth not in [16, 32, 64]:
raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem), val.type)
return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type)
def atom_red_typechecking_impl(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
op: str,
def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, op: str,
builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]:
if not ptr.type.scalar.is_ptr():
raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
@@ -1129,30 +1062,19 @@ def atom_red_typechecking_impl(ptr: tl.tensor,
return ptr, val, mask
def atomic_max(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
builder: ir.builder) -> tl.tensor:
def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder)
sem = _str_to_sem(sem)
scope = _str_to_scope(scope)
sca_ty = val.type.scalar
# direct call to atomic_max for integers
if sca_ty.is_int():
if sca_ty.is_int_signed():
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX,
ptr.handle,
val.handle,
mask.handle,
sem),
val.type)
return tl.tensor(
builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
else:
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX,
ptr.handle,
val.handle,
mask.handle,
sem),
val.type)
return tl.tensor(
builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
# ROCM TODO: implement atomic_max/min for f32 as they are supported by MI cards.
# for float
# return atomic_smax(i_ptr, i_val) if val >= 0
@@ -1167,36 +1089,29 @@ def atomic_max(ptr: tl.tensor,
i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder)
pos = greater_equal(val, zero, builder)
neg = less_than(val, zero, builder)
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle, sem), i_val.type)
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle, sem), i_val.type)
pos_ret = tl.tensor(
builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle,
and_(mask, pos, builder).handle, sem, scope), i_val.type)
neg_ret = tl.tensor(
builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle,
and_(mask, neg, builder).handle, sem, scope), i_val.type)
ret = where(pos, pos_ret, neg_ret, builder)
return bitcast(ret, sca_ty, builder)
def atomic_min(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
builder: ir.builder) -> tl.tensor:
def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder)
sem = _str_to_sem(sem)
scope = _str_to_scope(scope)
sca_ty = val.type.scalar
# direct call to atomic_min for integers
if sca_ty.is_int():
if sca_ty.is_int_signed():
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
ptr.handle,
val.handle,
mask.handle,
sem),
val.type)
return tl.tensor(
builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
else:
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN,
ptr.handle,
val.handle,
mask.handle,
sem),
val.type)
return tl.tensor(
builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
# for float
# return atomic_smin(i_ptr, i_val) if val >= 0
# return atomic_umax(i_ptr, i_val) if val < 0
@@ -1210,72 +1125,57 @@ def atomic_min(ptr: tl.tensor,
i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder)
pos = greater_equal(val, zero, builder)
neg = less_than(val, zero, builder)
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
i_ptr.handle,
i_val.handle,
and_(mask, pos, builder).handle,
sem),
i_val.type)
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX,
i_ptr.handle,
i_val.handle,
and_(mask, neg, builder).handle,
sem),
i_val.type)
pos_ret = tl.tensor(
builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle,
and_(mask, pos, builder).handle, sem, scope), i_val.type)
neg_ret = tl.tensor(
builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, i_ptr.handle, i_val.handle,
and_(mask, neg, builder).handle, sem, scope), i_val.type)
ret = where(pos, pos_ret, neg_ret, builder)
return bitcast(ret, sca_ty, builder)
def atomic_add(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
builder: ir.builder) -> tl.tensor:
def atomic_add(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder)
sem = _str_to_sem(sem)
scope = _str_to_scope(scope)
sca_ty = val.type.scalar
op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD
return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem), val.type)
return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
def atomic_and(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
builder: ir.builder) -> tl.tensor:
def atomic_and(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder)
sem = _str_to_sem(sem)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem), val.type)
scope = _str_to_scope(scope)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope),
val.type)
def atomic_or(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
builder: ir.builder) -> tl.tensor:
def atomic_or(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder)
sem = _str_to_sem(sem)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem), val.type)
scope = _str_to_scope(scope)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope),
val.type)
def atomic_xor(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
builder: ir.builder) -> tl.tensor:
def atomic_xor(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder)
sem = _str_to_sem(sem)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem), val.type)
scope = _str_to_scope(scope)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope),
val.type)
def atomic_xchg(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder)
sem = _str_to_sem(sem)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem), val.type)
scope = _str_to_scope(scope)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope),
val.type)
# ===----------------------------------------------------------------------===//
# Linear Algebra
@@ -1308,13 +1208,10 @@ def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty, target) -> bool:
return False
return True
def dot(lhs: tl.tensor,
rhs: tl.tensor,
acc: tl.tensor,
allow_tf32: bool,
max_num_imprecise_acc: int,
out_dtype: tl.dtype,
builder: ir.builder) -> tl.tensor:
def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_num_imprecise_acc: int,
out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
def assert_dtypes_valid(lhs_dtype, rhs_dtype, target):
# Checks for non-cuda archs
if is_hip():
@@ -1333,22 +1230,30 @@ def dot(lhs: tl.tensor,
# Checks for cuda archs
if target.capability < 90:
assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(), "Dot op does not support fp8e4nv on CUDA arch < 90"
assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(
), "Dot op does not support fp8e4nv on CUDA arch < 90"
if lhs_dtype.is_fp8() and rhs_dtype.is_fp8():
return
assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
else:
assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(), "Dot op does not support fp8e4b15 on CUDA arch >= 90"
assert not lhs_dtype.is_fp8e4b15x4() and not rhs_dtype.is_fp8e4b15x4(), "Dot op does not support fp8e4b15x4 on CUDA arch >= 90"
assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(
), "Dot op does not support fp8e4b15 on CUDA arch >= 90"
assert not lhs_dtype.is_fp8e4b15x4() and not rhs_dtype.is_fp8e4b15x4(
), "Dot op does not support fp8e4b15x4 on CUDA arch >= 90"
if lhs_dtype.is_int() or rhs_dtype.is_int():
assert lhs_dtype == rhs_dtype, f"Both operands must be same type. First operand ({lhs_dtype}) and second operand ({rhs_dtype})"
assert lhs_dtype.is_int8() or lhs_dtype.is_uint8(), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})"
assert lhs_dtype.is_int8() or lhs_dtype.is_uint8(
), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})"
elif lhs_dtype.is_fp8() or rhs_dtype.is_fp8():
assert lhs_dtype.is_fp8e4nv() or lhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. First operand ({lhs_dtype})"
assert rhs_dtype.is_fp8e4nv() or rhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. Second operand ({rhs_dtype})"
assert lhs_dtype.is_fp8e4nv() or lhs_dtype.is_fp8e5(
), f"Only supports fp8e4nv or fp8e5. First operand ({lhs_dtype})"
assert rhs_dtype.is_fp8e4nv() or rhs_dtype.is_fp8e5(
), f"Only supports fp8e4nv or fp8e5. Second operand ({rhs_dtype})"
else:
assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1(), f"Unsupported dtype {lhs_dtype}"
assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1(), f"Unsupported dtype {rhs_dtype}"
assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1(
), f"Unsupported dtype {lhs_dtype}"
assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1(
), f"Unsupported dtype {rhs_dtype}"
assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
assert lhs.type.is_block() and rhs.type.is_block()
@@ -1389,7 +1294,8 @@ def dot(lhs: tl.tensor,
_0 = builder.get_int32(0)
ret_scalar_ty = tl.int32
elif out_dtype.is_bf16():
raise ValueError("out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`")
raise ValueError(
"out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`")
elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16():
_0 = builder.get_fp32(0)
ret_scalar_ty = tl.float32
@@ -1418,7 +1324,8 @@ def dot(lhs: tl.tensor,
ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, max_num_imprecise_acc),
ret_ty)
return cast(ret, ret_scalar_ty, builder)
if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty, builder.target) and ret_scalar_ty.primitive_bitwidth <= 32:
if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32,
ret_scalar_ty, builder.target) and ret_scalar_ty.primitive_bitwidth <= 32:
# max_num_imprecise_acc does not yet apply to hip
if is_hip():
max_num_imprecise_acc = 0
@@ -1445,23 +1352,21 @@ def dot(lhs: tl.tensor,
assert acc.type == ret_ty
# max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
if not (_is_cuda(builder.target) and builder.target.capability == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and ret_scalar_ty.is_fp32()):
if not (_is_cuda(builder.target) and builder.target.capability == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8()
and ret_scalar_ty.is_fp32()):
max_num_imprecise_acc = 0
if max_num_imprecise_acc is None:
max_num_imprecise_acc = 2**30
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc),
ret_ty)
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc), ret_ty)
# ===----------------------------------------------------------------------===//
# Indexing
# ===----------------------------------------------------------------------===//
def where(condition: tl.tensor,
x: tl.tensor,
y: tl.tensor,
builder: ir.builder) -> tl.tensor:
def where(condition: tl.tensor, x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
condition = cast(condition, tl.int1, builder)
if condition.type.is_block():
condition, x = broadcast_impl_value(condition, x, builder)
@@ -1474,14 +1379,13 @@ def where(condition: tl.tensor,
ret_ty = x.type
return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
# ===----------------------------------------------------------------------===//
# Reduction
# ===----------------------------------------------------------------------===
def reduction(
inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder
) -> Tuple[tl.tensor, ...]:
def reduction(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder) -> Tuple[tl.tensor, ...]:
if axis is None:
new_inputs = []
for i in range(len(inputs)):
@@ -1507,10 +1411,7 @@ def reduction(
region_builder_fn(reduce_op)
reduce_op.verify()
return tuple(
wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar)
for i in range(len(inputs))
)
return tuple(wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar) for i in range(len(inputs)))
# ===----------------------------------------------------------------------===
@@ -1518,9 +1419,8 @@ def reduction(
# ===----------------------------------------------------------------------===
def associative_scan(
inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder
) -> Tuple[tl.tensor, ...]:
def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn,
builder: ir.builder) -> Tuple[tl.tensor, ...]:
if len(inputs) != 1:
raise ValueError("Current implementation only support single tensor input")
shape = inputs[0].type.shape
@@ -1533,16 +1433,14 @@ def associative_scan(
region_builder_fn(scan_op)
scan_op.verify()
return tuple(
wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar)
for i in range(len(inputs))
)
return tuple(wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar) for i in range(len(inputs)))
# ===----------------------------------------------------------------------===
# Math
# ===----------------------------------------------------------------------===
def _check_dtype(dtypes: List[str]) -> T:
"""
We're following libdevice's convention to check accepted data types for math functions.
@@ -1551,7 +1449,9 @@ def _check_dtype(dtypes: List[str]) -> T:
We should let the users know that they are using and invoke explicit cast to convert
the data type to the supported one.
"""
def wrapper(fn):
@wraps(fn)
def check(*args, **kwargs):
# concatenate args and kwargs
@@ -1560,6 +1460,7 @@ def _check_dtype(dtypes: List[str]) -> T:
if arg.type.scalar.name not in dtypes:
raise ValueError(f"Expected dtype {dtypes} but got {arg.type.scalar.name}")
return fn(*args, **kwargs)
return check
return wrapper
@@ -1645,6 +1546,15 @@ def debug_barrier(builder: ir.builder) -> tl.tensor:
def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:
# It makes sense visually for prefix to end in ": "; make it so. Also,
# non-empty prefixes should start with " ".
if not prefix.endswith(" ") and args:
prefix += " "
if not prefix.endswith(": ") and args:
prefix = prefix[:-1] + ": "
if len(prefix) > 2 and not prefix.startswith(" "):
prefix = " " + prefix
new_args = []
for arg in args:
new_args.append(arg.handle)
@@ -1654,8 +1564,8 @@ def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.
def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor:
cond_ty = cond.type
if not cond_ty.is_block():
cond_ty = tl.block_type(cond_ty.scalar, (1,))
cond = tl.tensor(builder.create_splat(cond.handle, (1,)), cond_ty)
cond_ty = tl.block_type(cond_ty.scalar, (1, ))
cond = tl.tensor(builder.create_splat(cond.handle, (1, )), cond_ty)
return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void)

View File

@@ -123,6 +123,7 @@ def maximum(x, y):
"""
return math.max(x, y)
# max and argmax
@@ -149,8 +150,7 @@ def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
@jit
@core._add_reduction_docstr("maximum",
return_indices_arg="return_indices",
@core._add_reduction_docstr("maximum", return_indices_arg="return_indices",
tie_break_arg="return_indices_tie_break_left")
def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
input = core._promote_reduction_input(input)
@@ -175,6 +175,7 @@ def argmax(input, axis, tie_break_left=True):
(_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
return ret
# min and argmin
@@ -201,8 +202,7 @@ def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
@jit
@core._add_reduction_docstr("minimum",
return_indices_arg="return_indices",
@core._add_reduction_docstr("minimum", return_indices_arg="return_indices",
tie_break_arg="return_indices_tie_break_left")
def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
input = core._promote_reduction_input(input)
@@ -222,8 +222,7 @@ def min(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
@jit
@core._add_reduction_docstr("minimum index",
tie_break_arg="tie_break_left")
@core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left")
def argmin(input, axis, tie_break_left=True):
_, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
return ret
@@ -233,6 +232,7 @@ def argmin(input, axis, tie_break_left=True):
def _sum_combine(a, b):
return a + b
# sum
@@ -247,6 +247,7 @@ def sum(input, axis=None):
def _xor_combine(a, b):
return a ^ b
# xor sum
@@ -258,8 +259,8 @@ def xor_sum(input, axis=None, _builder=None, _generator=None):
raise ValueError("xor_sum only supported for integers")
input = core._promote_reduction_input(input, _builder=_builder)
return core.reduce(input, axis, _xor_combine,
_builder=_builder, _generator=_generator)
return core.reduce(input, axis, _xor_combine, _builder=_builder, _generator=_generator)
# cumsum
@@ -271,6 +272,7 @@ def cumsum(input, axis=0):
input = core._promote_reduction_input(input)
return core.associative_scan(input, axis, _sum_combine)
# cumprod

View File

@@ -17,15 +17,14 @@ from ... import language as tl
'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
})
@jit
def _sdd_kernel(
A, B, C,
stride_za, stride_ha, stride_ma, stride_ak,
stride_zb, stride_hb, stride_bk, stride_nb,
stride_zc, stride_hc, stride_mc, stride_nc,
K, grid_offset, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
BLOCK: tl.constexpr, EVEN_K: tl.constexpr
):
def _sdd_kernel(A, B, C, #
stride_za, stride_ha, stride_ma, stride_ak, #
stride_zb, stride_hb, stride_bk, stride_nb, #
stride_zc, stride_hc, stride_mc, stride_nc, #
K, grid_offset, lut, #
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, #
BLOCK: tl.constexpr, EVEN_K: tl.constexpr #
):
# ------------ #
# - Prologue - #
# ------------ #
@@ -104,13 +103,13 @@ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=
c = out
grid = [c.shape[1], 1, c.shape[0]]
_sdd_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
Ka, 0, lut,
TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4,
num_warps=4,
a, b, c, #
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), #
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), #
c.stride(0), c.stride(1), c.stride(2), c.stride(3), #
Ka, 0, lut, #
TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4, #
num_warps=4 #
)
return c
@@ -120,6 +119,7 @@ def sdd_lut(layout, block, device):
lut = lut.contiguous()
return lut, None
# -----------------------------
# Dense = Sparse x Dense (DSD)
# This operation uses a look-up table that contains pre-computed pointer increments
@@ -128,15 +128,14 @@ def sdd_lut(layout, block, device):
@jit
def _dsd_kernel(
A, B, C,
stride_az, stride_ha, stride_am, stride_ak,
stride_zb, stride_hb, stride_bk, stride_bn,
stride_zc, stride_hc, stride_cm, stride_cn,
DS0, DS1, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
):
def _dsd_kernel(A, B, C, #
stride_az, stride_ha, stride_am, stride_ak, #
stride_zb, stride_hb, stride_bk, stride_bn, #
stride_zc, stride_hc, stride_cm, stride_cn, #
DS0, DS1, lut, #
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr #
):
# ------------ #
# - Prologue - #
# ------------ #
@@ -229,13 +228,13 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N
# compute output
grid = lambda meta: [cdiv(BS3, meta['TILE_N']), width, BS0]
_dsd_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
BS3, AS1, lut,
TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
num_warps=4, GROUP_SIZE_M=4,
a, b, c, #
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), #
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), #
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), #
BS3, AS1, lut, #
TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4, #
num_warps=4, GROUP_SIZE_M=4 #
)
# exit()
return c
@@ -337,6 +336,7 @@ def dsd_lut(layout, block, step, trans, device):
# create locks
return lut, width
# -----------------------------
# Dense = Dense x Sparse (DDS)
# -----------------------------
@@ -346,6 +346,7 @@ def dsd_lut(layout, block, step, trans, device):
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out)
##############
# MAIN API #
##############
@@ -356,10 +357,8 @@ class _matmul(torch.autograd.Function):
fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul}
@staticmethod
def forward(
ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
c_lut, c_width, da_lut, da_width, db_lut, db_width, out
):
def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_width, da_lut, da_width, db_lut,
db_width, out):
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out)
# save for backward
ctx.save_for_backward(a, b)
@@ -385,15 +384,13 @@ class _matmul(torch.autograd.Function):
# gradients w.r.t. a
if ctx.needs_input_grad[0]:
mode_da = mode[1] + mode[0] + mode[2]
da = _matmul.fn[mode_da](
dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width,
)
da = _matmul.fn[mode_da](dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block,
ctx.da_lut, ctx.da_width)
# gradients w.r.t. b
if ctx.needs_input_grad[1]:
mode_db = mode[2] + mode[1] + mode[0]
db = _matmul.fn[mode_db](
a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width,
)
db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block,
ctx.db_lut, ctx.db_width)
dout = dc if ctx.has_out else None
return da, db, None, None, None, \
None, None, None, None, \
@@ -427,11 +424,9 @@ class matmul:
self.db_lut, self.db_width = sdd_lut(layout, block, device)
def __call__(self, a, b, out=None):
c = _matmul.apply(
a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
self.c_lut, self.c_width,
self.da_lut, self.da_width,
self.db_lut, self.db_width,
out
)
c = _matmul.apply(a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, #
self.c_lut, self.c_width, #
self.da_lut, self.da_width, #
self.db_lut, self.db_width, #
out)
return c

View File

@@ -18,14 +18,13 @@ def num_warps(n):
@jit
def _blocksparse_softmax_fwd(
Out, A, stride_xz, LUT,
R, extent, stride_zr, stride_hr, # relative attention
scale, is_causal,
ROW_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
IS_DENSE: tl.constexpr,
):
def _blocksparse_softmax_fwd(Out, A, stride_xz, LUT, #
R, extent, stride_zr, stride_hr, # relative attention
scale, is_causal, #
ROW_SIZE: tl.constexpr, #
BLOCK_SIZE: tl.constexpr, #
IS_DENSE: tl.constexpr #
):
h = tl.program_id(0)
m = tl.program_id(1)
z = tl.program_id(2)
@@ -73,18 +72,16 @@ def _blocksparse_softmax_fwd(
@jit
def _blocksparse_softmax_bwd(
DA, stride_zdx,
DOut, stride_zdout,
Out, stride_zout,
scale,
LUT,
DR, extent, stride_zr, stride_hr, stride_er,
is_causal,
ROW_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
IS_DENSE: tl.constexpr,
):
def _blocksparse_softmax_bwd(DA, stride_zdx, #
DOut, stride_zdout, #
Out, stride_zout, #
scale, #
LUT, #
DR, extent, stride_zr, stride_hr, stride_er, #
is_causal, #
ROW_SIZE: tl.constexpr, #
BLOCK_SIZE: tl.constexpr, #
IS_DENSE: tl.constexpr):
h = tl.program_id(0)
m = tl.program_id(1)
z = tl.program_id(2)
@@ -133,6 +130,7 @@ def _blocksparse_softmax_bwd(
class _softmax(torch.autograd.Function):
@staticmethod
def make_lut(layout, block, device):
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
@@ -151,10 +149,7 @@ class _softmax(torch.autograd.Function):
return lut, int(total_sizes.max())
@staticmethod
def forward(
ctx, a, scale, rel_logits, is_causal,
spdims, block, lut, maxlut, is_dense
):
def forward(ctx, a, scale, rel_logits, is_causal, spdims, block, lut, maxlut, is_dense):
if scale is not None and isinstance(scale, torch.Tensor):
assert scale.device.type == "cpu"
scale = scale.item()
@@ -165,14 +160,14 @@ class _softmax(torch.autograd.Function):
# enqueue kernel
out = torch.empty_like(a)
_blocksparse_softmax_fwd[grid](
out, a, a.stride(0), lut,
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
scale,
is_causal,
BLOCK_SIZE=block,
ROW_SIZE=next_power_of_2(maxlut),
IS_DENSE=is_dense,
num_warps=num_warps(maxlut)
out, a, a.stride(0), lut, #
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn#
scale, #
is_causal, #
BLOCK_SIZE=block, #
ROW_SIZE=next_power_of_2(maxlut), #
IS_DENSE=is_dense, #
num_warps=num_warps(maxlut) #
)
# save to context
# ctx.mark_dirty(x)
@@ -201,28 +196,23 @@ class _softmax(torch.autograd.Function):
grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)
da = torch.empty_like(dout)
_blocksparse_softmax_bwd[grid](
da, da.stride(0),
dout, dout.stride(0),
out, out.stride(0),
ctx.scale,
lut,
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],
ctx.is_causal,
BLOCK_SIZE=ctx.block,
ROW_SIZE=next_power_of_2(ctx.maxlut),
IS_DENSE=ctx.is_dense,
num_warps=num_warps(ctx.maxlut)
da, da.stride(0), #
dout, dout.stride(0), #
out, out.stride(0), #
ctx.scale, #
lut, #
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2], #
ctx.is_causal, #
BLOCK_SIZE=ctx.block, #
ROW_SIZE=next_power_of_2(ctx.maxlut), #
IS_DENSE=ctx.is_dense, #
num_warps=num_warps(ctx.maxlut) #
)
return (da, None, None, dr, None,
None, None, None, None, None,
None,
None, None, None,
None,
None, None, None
)
return (da, None, None, dr, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
class softmax:
def __init__(self, layout, block, device, is_dense=False):
self.spdims = layout.shape
self.layout = layout
@@ -233,8 +223,6 @@ class softmax:
def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False):
if rel_logits is not None and rel_logits.dtype != a.dtype:
raise ValueError(f"relative position embedding must be {a.dtype}")
a = _softmax.apply(
a, scale, rel_logits, is_causal,
self.spdims, self.block, self.lut, self.maxlut, self.is_dense,
)
a = _softmax.apply(a, scale, rel_logits, is_causal, self.spdims, self.block, self.lut, self.maxlut,
self.is_dense)
return a

View File

@@ -59,6 +59,7 @@ def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
class _cross_entropy(torch.autograd.Function):
@classmethod
def forward(cls, ctx, logits, indices):
# make sure we can use triton

View File

@@ -15,20 +15,19 @@ from .. import language as tl
@jit
def _fwd_kernel(
Q, K, V, sm_scale,
L,
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX,
Z_H_N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
IS_CAUSAL: tl.constexpr,
):
def _fwd_kernel(Q, K, V, sm_scale, #
L, #
Out, #
stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vn, stride_vk, #
stride_oz, stride_oh, stride_om, stride_on, #
Z, H, N_CTX, #
Z_H_N_CTX, #
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
BLOCK_N: tl.constexpr, #
IS_CAUSAL: tl.constexpr #
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
qvk_offset = off_hz * stride_qh
@@ -40,7 +39,7 @@ def _fwd_kernel(
strides=(stride_kk, stride_kn),
offsets=(0, vk_offset),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
order=(0, 1),
)
V_block_ptr = tl.make_block_ptr(
base=V,
@@ -48,7 +47,7 @@ def _fwd_kernel(
strides=(stride_vn, stride_vk),
offsets=(vk_offset, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
order=(1, 0),
)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
@@ -104,7 +103,7 @@ def _fwd_kernel(
strides=(stride_om, stride_on),
offsets=(vk_offset + start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
order=(1, 0),
)
# O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
tl.store(O_block_ptr, acc.to(K.dtype.element_ty))
@@ -112,9 +111,11 @@ def _fwd_kernel(
@jit
def _bwd_preprocess(
Out, DO,
Out,
DO,
Delta,
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
BLOCK_M: tl.constexpr,
D_HEAD: tl.constexpr,
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
@@ -128,40 +129,48 @@ def _bwd_preprocess(
@jit
def _bwd_kernel_one_col_block(
Q, K, V, sm_scale, qk_scale,
Out, DO,
DQ, DK, DV,
L,
D,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
Z, H, N_CTX,
off_hz, start_n, num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
CAUSAL: tl.constexpr,
MMA_V3: tl.constexpr
):
if SEQUENCE_PARALLEL:
DQ += stride_dqa.to(tl.int64) * start_n
def _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, #
Out, DO, #
DQ, DK, DV, #
L, #
D, #
Q_block_ptr, K_block_ptr, V_block_ptr, #
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vn, stride_vk, #
Z, H, N_CTX, #
off_h, off_z, off_hz, start_n, num_block, #
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
BLOCK_N: tl.constexpr, #
SEQUENCE_PARALLEL: tl.constexpr, #
CAUSAL: tl.constexpr, #
MMA_V3: tl.constexpr #
):
if CAUSAL:
lo = start_n * BLOCK_M
else:
lo = 0
Q_offset = (off_z * stride_qz + off_h * stride_qh) // stride_qm
DQ_offset = off_z * stride_qz + off_h * stride_qh
K_offset = (off_z * stride_kz + off_h * stride_kh) // stride_kn
V_offset = (off_z * stride_vz + off_h * stride_vh) // stride_vn
if SEQUENCE_PARALLEL:
DQ_offset += stride_dqa.to(tl.int64) * start_n
DQ_offset = DQ_offset // stride_qm
Q_block_ptr = tl.advance(Q_block_ptr, (lo + Q_offset, 0))
K_block_ptr = tl.advance(K_block_ptr, (start_n * BLOCK_M + K_offset, 0))
V_block_ptr = tl.advance(V_block_ptr, (start_n * BLOCK_M + V_offset, 0))
DO_block_ptr = tl.advance(DO_block_ptr, (lo + Q_offset, 0))
DQ_block_ptr = tl.advance(DQ_block_ptr, (lo + DQ_offset, 0))
DK_block_ptr = tl.advance(DK_block_ptr, (start_n * BLOCK_M + K_offset, 0))
DV_block_ptr = tl.advance(DV_block_ptr, (start_n * BLOCK_M + V_offset, 0))
# initialize row/col offsets
offs_qm = lo + tl.arange(0, BLOCK_M)
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
offs_m = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_DMODEL)
# initialize pointers to value-like data
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk)
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
l_ptrs = L + off_hz * N_CTX
@@ -169,17 +178,17 @@ def _bwd_kernel_one_col_block(
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# k and v stay in SRAM throughout
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
k = tl.load(K_block_ptr)
v = tl.load(V_block_ptr)
# loop over rows
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip
q = tl.load(q_ptrs)
q = tl.load(Q_block_ptr)
# recompute p = softmax(qk, dim=-1).T
# NOTE: `do` is pre-divided by `l`; no normalization here
if CAUSAL:
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.0), float("-inf"))
else:
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
@@ -187,7 +196,7 @@ def _bwd_kernel_one_col_block(
l_i = tl.load(l_ptrs + offs_m_curr)
p = tl.math.exp2(qk - l_i[:, None])
# compute dv
do = tl.load(do_ptrs)
do = tl.load(DO_block_ptr)
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do, allow_tf32=True)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
@@ -199,97 +208,156 @@ def _bwd_kernel_one_col_block(
dk += tl.dot(tl.trans(ds), q, allow_tf32=True)
# compute dq
if not SEQUENCE_PARALLEL:
dq = tl.load(dq_ptrs)
dq = tl.load(DQ_block_ptr)
dq += tl.dot(ds, k, allow_tf32=True)
tl.store(dq_ptrs, dq)
tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty))
elif SEQUENCE_PARALLEL:
if MMA_V3:
dq = tl.dot(ds, k, allow_tf32=True)
else:
# not work with mma v3, becuase M % 64 != 0
dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds), allow_tf32=True))
tl.store(dq_ptrs, dq)
tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty))
# increment pointers
dq_ptrs += BLOCK_M * stride_qm
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm
DQ_block_ptr = tl.advance(DQ_block_ptr, (BLOCK_M, 0))
Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0))
DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0))
# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk)
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
tl.store(DV_block_ptr, dv.to(V.dtype.element_ty))
tl.store(DK_block_ptr, dk.to(K.dtype.element_ty))
@jit
def _bwd_kernel(
# fmt: off
Q, K, V, sm_scale,
Out, DO,
DQ, DK, DV,
L,
D,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
CAUSAL: tl.constexpr,
MMA_V3: tl.constexpr
# fmt: on
):
def _bwd_kernel(Q, K, V, sm_scale, #
Out, DO, #
DQ, DK, DV, #
L, #
D, #
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vn, stride_vk, #
Z, H, N_CTX, #
Z_H_N_CTX, #
SQ_Z_H_N_CTX, #
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
BLOCK_N: tl.constexpr, #
SEQUENCE_PARALLEL: tl.constexpr, #
CAUSAL: tl.constexpr, #
MMA_V3: tl.constexpr #
):
qk_scale = sm_scale * 1.44269504
off_hz = tl.program_id(0)
off_z = off_hz // H
off_h = off_hz % H
# offset pointers for batch/head
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_kz + off_h * stride_kh
V += off_z * stride_vz + off_h * stride_vh
DO += off_z * stride_qz + off_h * stride_qh
DQ += off_z * stride_qz + off_h * stride_qh
DK += off_z * stride_kz + off_h * stride_kh
DV += off_z * stride_vz + off_h * stride_vh
Q_block_ptr = tl.make_block_ptr(
base=Q,
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
K_block_ptr = tl.make_block_ptr(
base=K,
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_kn, stride_kk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
V_block_ptr = tl.make_block_ptr(
base=V,
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_vn, stride_vk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
DO_block_ptr = tl.make_block_ptr(
base=DO,
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
if SEQUENCE_PARALLEL:
DQ_block_ptr = tl.make_block_ptr(
base=DQ,
shape=(SQ_Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
else:
DQ_block_ptr = tl.make_block_ptr(
base=DQ,
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
DK_block_ptr = tl.make_block_ptr(
base=DK,
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_kn, stride_kk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
DV_block_ptr = tl.make_block_ptr(
base=DV,
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_vn, stride_vk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
num_block_n = tl.cdiv(N_CTX, BLOCK_N)
if not SEQUENCE_PARALLEL:
for start_n in range(0, num_block_n):
_bwd_kernel_one_col_block(
Q, K, V, sm_scale, qk_scale, Out, DO,
DQ, DK, DV,
L,
D,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
Z, H, N_CTX,
off_hz, start_n, num_block_n,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
MMA_V3=MMA_V3
)
_bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, #
DQ, DK, DV, #
L, #
D, #
Q_block_ptr, K_block_ptr, V_block_ptr, #
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vn, stride_vk, #
Z, H, N_CTX, #
off_h, off_z, off_hz, start_n, num_block_n, #
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, #
BLOCK_N=BLOCK_N, #
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, #
CAUSAL=CAUSAL, #
MMA_V3=MMA_V3 #
)
else:
start_n = tl.program_id(1)
_bwd_kernel_one_col_block(
Q, K, V, sm_scale, qk_scale, Out, DO,
DQ, DK, DV,
L,
D,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
Z, H, N_CTX,
off_hz, start_n, num_block_n,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
MMA_V3=MMA_V3
)
_bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, #
DQ, DK, DV, #
L, #
D, #
Q_block_ptr, K_block_ptr, V_block_ptr, #
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vn, stride_vk, #
Z, H, N_CTX, #
off_h, off_z, off_hz, start_n, num_block_n, #
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, #
BLOCK_N=BLOCK_N, #
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, #
CAUSAL=CAUSAL, #
MMA_V3=MMA_V3 #
)
class _attention(torch.autograd.Function):
@@ -315,19 +383,20 @@ class _attention(torch.autograd.Function):
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
_fwd_kernel[grid](
q, k, v, sm_scale,
L,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
q.shape[0] * q.shape[1] * q.shape[2],
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
IS_CAUSAL=causal,
num_warps=num_warps,
num_stages=4)
q, k, v, sm_scale, #
L, #
o, #
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
q.shape[0], q.shape[1], q.shape[2], #
q.shape[0] * q.shape[1] * q.shape[2], #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, #
IS_CAUSAL=causal, #
num_warps=num_warps, #
num_stages=4 #
)
ctx.save_for_backward(q, k, v, o, L)
ctx.grid = grid
@@ -348,35 +417,39 @@ class _attention(torch.autograd.Function):
do = do.contiguous()
if sequence_parallel:
replicas = cdiv(seq_len_kv, BLOCK)
new_dq_shape = (replicas,) + q.shape
new_dq_shape = (replicas, ) + q.shape
dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype)
else:
dq = torch.zeros_like(q, dtype=torch.float32)
dq = torch.zeros_like(q, dtype=q.dtype)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
delta = torch.empty_like(L)
_bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1], )](
o, do,
o,
do,
delta,
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
BLOCK_M=BLOCK,
D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)](
q, k, v, ctx.sm_scale,
o, do,
dq, dk, dv,
L,
delta,
o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
SEQUENCE_PARALLEL=sequence_parallel,
CAUSAL=ctx.causal,
MMA_V3=MMA_V3,
num_warps=8,
num_stages=1,
q, k, v, ctx.sm_scale, #
o, do, #
dq, dk, dv, #
L, #
delta, #
o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
q.shape[0], q.shape[1], q.shape[2], #
q.shape[0] * q.shape[1] * q.shape[2], #
cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2], #
BLOCK_M=BLOCK, BLOCK_N=BLOCK, #
BLOCK_DMODEL=ctx.BLOCK_DMODEL, #
SEQUENCE_PARALLEL=sequence_parallel, #
CAUSAL=ctx.causal, #
MMA_V3=MMA_V3, #
num_warps=8, #
num_stages=1 #
)
if len(dq.shape) == 5:

View File

@@ -37,8 +37,9 @@ def get_configs_io_bound():
num_stages=num_stages, num_warps=num_warps))
# split_k
for split_k in [2, 4, 8, 16]:
configs.append(Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
configs.append(
Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs
@@ -69,22 +70,22 @@ def get_configs_io_bound():
prune_configs_by={
'early_config_prune': early_config_prune,
'perf_model': estimate_matmul_time,
'top_k': 10
'top_k': 10,
},
)
@heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@jit
def _kernel(A, B, C, M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
dot_out_dtype: tl.constexpr,
allow_tf32: tl.constexpr,
fp8_fast_accum: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr
def _kernel(A, B, C, M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
dot_out_dtype: tl.constexpr, #
allow_tf32: tl.constexpr, #
fp8_fast_accum: tl.constexpr, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr #
):
# matrix multiplication
pid = tl.program_id(0)
@@ -184,14 +185,15 @@ class _matmul(torch.autograd.Function):
ab_dtype = False
# launch kernel
grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
dot_out_dtype=dot_out_dtype,
allow_tf32=allow_tf32,
fp8_fast_accum=fp8_fast_accum,
GROUP_M=8, AB_DTYPE=ab_dtype)
_kernel[grid](
a, b, c, M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
dot_out_dtype=dot_out_dtype, #
allow_tf32=allow_tf32, #
fp8_fast_accum=fp8_fast_accum, #
GROUP_M=8, AB_DTYPE=ab_dtype)
return c
@staticmethod

View File

@@ -5,8 +5,7 @@ import torch
from .. import cdiv
from .._C.libtriton.triton import runtime
from ..runtime import driver
from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops,
nvsmi)
from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops, nvsmi)
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
@@ -14,7 +13,8 @@ def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
total_warps = num_ctas * min(num_warps, 4)
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, cur_sm_clock, backend, device)
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(
dtype, cur_sm_clock, backend, device)
return tflops
@@ -35,12 +35,12 @@ def get_tflops(backend, device, num_ctas, num_warps, dtype):
def estimate_matmul_time(
# backend, device,
num_warps, num_stages,
A, B, C,
M, N, K,
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
debug=False, **kwargs
# backend, device,
num_warps, num_stages, #
A, B, C, #
M, N, K, #
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, #
debug=False, **kwargs #
):
''' return estimated running time in ms
= max(compute, loading) + store '''
@@ -149,8 +149,9 @@ def early_config_prune(configs, named_args):
optimal_num_stages = ldgsts_latency / mma_cycles
# nearest stages, prefer large #stages
nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
nearest = heapq.nsmallest(
2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
for n in nearest:
pruned_configs.append(n[0])

View File

@@ -1,8 +1,6 @@
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune,
heuristics)
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune, heuristics)
from .driver import driver
from .jit import (JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret,
version_key)
from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret
__all__ = [
"driver",
@@ -12,7 +10,6 @@ __all__ = [
"heuristics",
"JITFunction",
"KernelInterface",
"version_key",
"reinterpret",
"TensorWrapper",
"OutOfResources",

View File

@@ -9,11 +9,10 @@ from .jit import KernelInterface
class OutOfResources(Exception):
def __init__(self, required, limit, name):
self.message = f'out of resource: {name}, '\
f'Required: {required}, '\
f'Hardware limit: {limit}'
self.message += '. Reducing block sizes or `num_stages` may help.'
self.message = (f"out of resource: {name}, Required: {required}, Hardware limit: {limit}. " +
"Reducing block sizes or `num_stages` may help.")
self.required = required
self.limit = limit
self.name = name
@@ -25,38 +24,73 @@ class OutOfResources(Exception):
class Autotuner(KernelInterface):
def __init__(self, fn, arg_names, configs, key, verbose, reset_to_zero, prune_configs_by: Dict = None, warmup=25, rep=100):
'''
def __init__(
self,
fn,
arg_names,
configs,
key,
verbose,
reset_to_zero,
restore_value,
prune_configs_by: Dict = None,
warmup=25,
rep=100,
):
"""
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
'''
"""
if not configs:
self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)]
else:
self.configs = configs
self.key_idx = [arg_names.index(k) for k in key]
self.cache = {}
# hook to reset all required tensor to zeros before relaunching a kernel
self.hook = lambda args: 0
self.arg_names = arg_names
# Reset to zero or restore values
self.reset_idx = []
if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
self.restore_idx = []
if restore_value is not None:
self.restore_idx = [arg_names.index(k) for k in restore_value]
def _hook(args):
# Hook to reset or restore for required tensors
self.pre_hook = lambda args, reset_only=False: 0
self.post_hook = lambda args: 0
if len(self.reset_idx) > 0 or len(self.restore_idx) > 0:
def _pre_hook(args, reset_only=False):
for i in self.reset_idx:
args[i].zero_()
self.hook = _hook
self.arg_names = arg_names
# prune configs
if not reset_only:
self.restore_copies = [args[i].clone() for i in self.restore_idx]
self.pre_hook = _pre_hook
if len(self.restore_idx) > 0:
def _post_hook(args):
for i, j in enumerate(self.restore_idx):
args[j].copy_(self.restore_copies[i])
self.restore_copies = []
self.post_hook = _post_hook
# Prune configs
if prune_configs_by:
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
if 'early_config_prune' in prune_configs_by:
early_config_prune = prune_configs_by['early_config_prune']
perf_model, top_k = prune_configs_by["perf_model"], prune_configs_by["top_k"]
if "early_config_prune" in prune_configs_by:
early_config_prune = prune_configs_by["early_config_prune"]
else:
perf_model, top_k, early_config_prune = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.early_config_prune = early_config_prune
self.fn = fn
self.warmup = warmup
self.rep = rep
@@ -67,10 +101,8 @@ class Autotuner(KernelInterface):
# as kwargs and by the autotuner
conflicts = meta.keys() & config.kwargs.keys()
if conflicts:
raise ValueError(
f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols."
)
raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols.")
# augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs)
full_nargs = {**self.nargs, **current}
@@ -78,16 +110,22 @@ class Autotuner(KernelInterface):
def kernel_call():
if config.pre_hook:
config.pre_hook(full_nargs)
self.hook(args)
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization,
# enable_persistent=False,
**current)
self.pre_hook(args)
self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization,
# enable_persistent=False,
**current,
)
self.post_hook(args)
try:
return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
except OutOfResources:
return [float('inf'), float('inf'), float('inf')]
return [float("inf"), float("inf"), float("inf")]
def get_best_config(self):
return self.best_config
@@ -110,12 +148,11 @@ class Autotuner(KernelInterface):
# prune configs
pruned_configs = self.prune_configs(kwargs)
bench_start = time.time()
timings = {config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs}
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
bench_end = time.time()
self.bench_time = bench_end - bench_start
self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args)
self.pre_hook(args, reset_only=True)
self.configs_timings = timings
if self.verbose:
print(str(key) + ": " + str(self.cache[key]))
@@ -126,9 +163,15 @@ class Autotuner(KernelInterface):
full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs}
if config.pre_hook is not None:
config.pre_hook(full_nargs)
ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization, **kwargs, **config.kwargs)
ret = self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization,
**kwargs,
**config.kwargs,
)
self.nargs = None
return ret
@@ -142,17 +185,20 @@ class Autotuner(KernelInterface):
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {
config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
num_warps=config.num_warps,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization,
enable_persistent=config.enable_persistent)
config:
self.perf_model(
**self.nargs,
**kwargs,
**config.kwargs,
num_stages=config.num_stages,
num_warps=config.num_warps,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization,
enable_persistent=config.enable_persistent,
)
for config in pruned_configs
}
pruned_configs = sorted(
est_timing.keys(),
key=lambda x: est_timing[x])[
:top_k]
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
return pruned_configs
def warmup(self, *args, **kwargs):
@@ -195,7 +241,7 @@ class Config:
self.num_ctas = num_ctas
self.num_stages = num_stages
self.enable_warp_specialization = enable_warp_specialization
# TODO[shuhaoj]: May make enable_persistent configurable in future if necessay.
# TODO[shuhaoj]: May make enable_persistent configurable in future if necessary.
self.enable_persistent = False
self.pre_hook = pre_hook
@@ -207,13 +253,12 @@ class Config:
## Comment out Hopper specific parameters
#res.append(f'num_ctas: {self.num_ctas}')
res.append(f'num_stages: {self.num_stages}')
#res.append(
# f'enable_warp_specialization: {self.enable_warp_specialization}')
#res.append(f'enable_warp_specialization: {self.enable_warp_specialization}')
#res.append(f'enable_persistent: {self.enable_persistent}')
return ', '.join(res)
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=False, warmup=25, rep=100):
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, verbose=False, warmup=25, rep=100):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
@@ -244,6 +289,8 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=Fa
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
:param restore_value: a list of argument names whose value will be restored after evaluating any configs.
:type restore_value: list[str]
:param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25.
:type warmup: int
:param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
@@ -251,8 +298,9 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=Fa
:param verbose: a boolean that controls whether the best_config for each key is printed
:type verbose: bool
"""
def decorator(fn):
return Autotuner(fn, fn.arg_names, configs, key, verbose, reset_to_zero, prune_configs_by, warmup, rep)
return Autotuner(fn, fn.arg_names, configs, key, verbose, reset_to_zero, restore_value, prune_configs_by, warmup, rep)
return decorator
@@ -286,6 +334,7 @@ def heuristics(values):
each such function takes a list of positional arguments as input.
:type values: dict[str, Callable[[list[Any]], Any]]
"""
def decorator(fn):
return Heuristics(fn, fn.arg_names, values)

View File

@@ -1,27 +1,42 @@
#include "cuda.h"
#include <dlfcn.h>
#include <stdbool.h>
#define PY_SSIZE_T_CLEAN
#include <Python.h>
static inline void gpuAssert(CUresult code, const char *file, int line) {
if (code != CUDA_SUCCESS) {
const char *prefix = "Triton Error [CUDA]: ";
const char *str;
cuGetErrorString(code, &str);
char err[1024] = {0};
strcat(err, prefix);
strcat(err, str);
PyGILState_STATE gil_state;
gil_state = PyGILState_Ensure();
PyErr_SetString(PyExc_RuntimeError, err);
PyGILState_Release(gil_state);
}
// Raises a Python exception and returns false if code is not CUDA_SUCCESS.
static bool gpuAssert(CUresult code, const char *file, int line) {
if (code == CUDA_SUCCESS)
return true;
const char *prefix = "Triton Error [CUDA]: ";
const char *str;
cuGetErrorString(code, &str);
char err[1024] = {0};
strcat(err, prefix);
strcat(err, str);
PyGILState_STATE gil_state;
gil_state = PyGILState_Ensure();
PyErr_SetString(PyExc_RuntimeError, err);
PyGILState_Release(gil_state);
return false;
}
#define CUDA_CHECK(ans) \
{ \
{ gpuAssert((ans), __FILE__, __LINE__); } \
}
// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block.
#define CUDA_CHECK_AND_RETURN_NULL(ans) \
do { \
if (!gpuAssert((ans), __FILE__, __LINE__)) \
return NULL; \
} while (0)
// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block.
#define CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \
do { \
if (!gpuAssert((ans), __FILE__, __LINE__)) { \
PyEval_RestoreThread(_save); \
return NULL; \
} \
} while (0)
#define ADD_ENUM_ITEM(value) \
do { \
@@ -200,16 +215,16 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
int sm_clock_rate;
int mem_clock_rate;
int mem_bus_width;
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate,
CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
&sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
@@ -237,33 +252,37 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
CUcontext pctx = 0;
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuCtxGetCurrent(&pctx));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&pctx));
if (!pctx) {
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
CUDA_CHECK(cuCtxSetCurrent(pctx));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuDevicePrimaryCtxRetain(&pctx, device));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(pctx));
}
CUDA_CHECK(cuModuleLoadData(&mod, data));
CUDA_CHECK(cuModuleGetFunction(&fun, mod, name));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuModuleLoadData(&mod, data));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuModuleGetFunction(&fun, mod, name));
// get allocated registers and spilled registers from the function
CUDA_CHECK(cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
CUDA_CHECK(
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
n_spills /= 4;
// set dynamic shared memory if necessary
int shared_optin;
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
if (shared > 49152 && shared_optin > 49152) {
CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
int shared_total, shared_static;
CUDA_CHECK(cuDeviceGetAttribute(
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
device));
CUDA_CHECK(cuFuncGetAttribute(&shared_static,
CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
CUDA_CHECK(
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute(
&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_optin - shared_static));
}
@@ -286,7 +305,7 @@ static PyObject *memAlloc(PyObject *self, PyObject *args) {
}
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuMemAlloc(&dptr, bytesize));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuMemAlloc(&dptr, bytesize));
Py_END_ALLOW_THREADS;
return PyLong_FromUnsignedLongLong((unsigned long long)dptr);
@@ -307,7 +326,8 @@ static PyObject *memcpyHtoD(PyObject *self, PyObject *args) {
srcHost = (const void *)srcHostPtr;
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuMemcpyHtoD(dstDevice, srcHost, byteCount));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuMemcpyHtoD(dstDevice, srcHost, byteCount));
Py_END_ALLOW_THREADS;
Py_RETURN_NONE;
@@ -321,7 +341,7 @@ static PyObject *memFree(PyObject *self, PyObject *args) {
}
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuMemFree(dptr));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuMemFree(dptr));
Py_END_ALLOW_THREADS;
Py_RETURN_NONE;
@@ -411,7 +431,7 @@ static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) {
}
// Call the function
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK(cuTensorMapEncodeTiledHandle(
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuTensorMapEncodeTiledHandle(
tensorMap, tensorDataType, tensorRank, globalAddress, globalDim,
globalStrides, boxDim, elementStrides, interleave, swizzle, l2Promotion,
oobFill));

View File

@@ -19,6 +19,7 @@ def default_dump_dir():
class CacheManager(ABC):
def __init__(self, key):
pass
@@ -44,20 +45,21 @@ class CacheManager(ABC):
class FileCacheManager(CacheManager):
def __init__(self, key, override=False, dump=False):
self.key = key
self.lock_path = None
if (dump):
if dump:
self.cache_dir = default_dump_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
elif (override):
elif override:
self.cache_dir = default_override_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
else:
# create cache directory if it doesn't exist
self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir()
self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
if self.cache_dir:
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
@@ -93,9 +95,8 @@ class FileCacheManager(CacheManager):
result = {}
for c in child_paths:
p = self._make_path(c)
if not os.path.exists(p):
raise Exception(f"Group file {p} does not exist from group {grp_filename} ")
result[c] = p
if os.path.exists(p):
result[c] = p
return result
# Note a group of pushed files as being part of a group
@@ -142,6 +143,7 @@ def get_cache_manager(key) -> CacheManager:
if user_cache_manager is not None and user_cache_manager != __cache_cls_nme:
import importlib
module_path, clz_nme = user_cache_manager.split(":")
module = importlib.import_module(module_path)
__cache_cls = getattr(module, clz_nme)

View File

@@ -9,7 +9,6 @@ from .cache import get_cache_manager
class DriverBase(metaclass=abc.ABCMeta):
CUDA = 0
HIP = 1
@@ -19,6 +18,8 @@ class DriverBase(metaclass=abc.ABCMeta):
def __init__(self) -> None:
pass
# -----------------------------
# CUDA
# -----------------------------
@@ -27,7 +28,7 @@ class DriverBase(metaclass=abc.ABCMeta):
class CudaUtils(object):
def __new__(cls):
if not hasattr(cls, 'instance'):
if not hasattr(cls, "instance"):
cls.instance = super(CudaUtils, cls).__new__(cls)
return cls.instance
@@ -47,6 +48,7 @@ class CudaUtils(object):
with open(so, "rb") as f:
cache_path = cache.put(f.read(), fname, binary=True)
import importlib.util
spec = importlib.util.spec_from_file_location("cuda_utils", cache_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
@@ -66,7 +68,7 @@ class CudaUtils(object):
class CudaDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, 'instance'):
if not hasattr(cls, "instance"):
cls.instance = super(CudaDriver, cls).__new__(cls)
return cls.instance
@@ -74,14 +76,16 @@ class CudaDriver(DriverBase):
self.utils = CudaUtils()
self.backend = self.CUDA
# -----------------------------
# HIP
# -----------------------------
class HIPUtils(object):
def __new__(cls):
if not hasattr(cls, 'instance'):
if not hasattr(cls, "instance"):
cls.instance = super(HIPUtils, cls).__new__(cls)
return cls.instance
@@ -101,6 +105,7 @@ class HIPUtils(object):
with open(so, "rb") as f:
cache_path = cache.put(f.read(), fname, binary=True)
import importlib.util
spec = importlib.util.spec_from_file_location("hip_utils", cache_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
@@ -111,7 +116,7 @@ class HIPUtils(object):
class HIPDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, 'instance'):
if not hasattr(cls, "instance"):
cls.instance = super(HIPDriver, cls).__new__(cls)
return cls.instance
@@ -123,7 +128,7 @@ class HIPDriver(DriverBase):
class UnsupportedDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, 'instance'):
if not hasattr(cls, "instance"):
cls.instance = super(UnsupportedDriver, cls).__new__(cls)
return cls.instance
@@ -131,12 +136,14 @@ class UnsupportedDriver(DriverBase):
self.utils = None
self.backend = None
# -----------------------------
# Driver
# -----------------------------
class LazyProxy:
def __init__(self, init_fn):
self._init_fn = init_fn
self._obj = None
@@ -150,7 +157,7 @@ class LazyProxy:
return getattr(self._obj, name)
def __setattr__(self, name, value):
if name in ['_init_fn', '_obj']:
if name in ["_init_fn", "_obj"]:
super().__setattr__(name, value)
else:
self._initialize_obj()
@@ -172,6 +179,7 @@ class LazyProxy:
def initialize_driver():
import torch
if torch.version.hip is not None:
return HIPDriver()
elif torch.cuda.is_available():

View File

@@ -1,10 +1,8 @@
class OutOfResources(Exception):
def __init__(self, required, limit, name):
self.message = f'out of resource: {name}, '\
f'Required: {required}, '\
f'Hardware limit: {limit}'
self.message += '. Reducing block sizes or `num_stages` may help.'
self.message = f"out of resource: {name}, " f"Required: {required}, " f"Hardware limit: {limit}"
self.message += ". Reducing block sizes or `num_stages` may help."
self.required = required
self.limit = limit
self.name = name

View File

@@ -74,11 +74,15 @@ class BlockPointerHandle:
def wrap_ret(compute_ret_ty):
def wrapper(fn):
def wrapped(*args, **kwargs):
ret = fn(*args, **kwargs)
return TensorHandle(ret.data, compute_ret_ty(*args, **kwargs))
return wrapped
return wrapper
@@ -249,11 +253,13 @@ class Builder:
# ternary functions
def ternary_op(self, lhs, rhs, other, op):
return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype)
create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where)
# unary functions
def unary_op(self, arg, op):
return TensorHandle(op(arg.data), arg.dtype)
create_exp = lambda self, arg: self.unary_op(arg, np.exp)
create_cos = lambda self, arg: self.unary_op(arg, np.cos)
create_sin = lambda self, arg: self.unary_op(arg, np.sin)
@@ -279,7 +285,8 @@ class Builder:
dtype_tt = ptr.dtype.element_ty
return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype)
def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile):
def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy,
is_volatile):
ptrs, masks = ptr.materialize_pointers(boundary_check)
assert padding_option is None
other = None
@@ -297,6 +304,7 @@ class Builder:
def create_int_to_ptr(self, val, dst_ty):
return TensorHandle(val.data.astype(np.uint64), dst_ty)
# def create_cat(self, lhs, rhs):
# pass
@@ -360,7 +368,10 @@ class Builder:
def patch_attr(obj, name, member, builder):
new_member = lambda *args, member=member, **kwargs: (member(*args, **{k: v for k, v in kwargs.items() if k != '_builder'}, _builder=builder))
new_member = lambda *args, member=member, **kwargs: (member(*args, **
{k: v
for k, v in kwargs.items()
if k != "_builder"}, _builder=builder))
setattr(obj, name, new_member)
@@ -384,8 +395,8 @@ def _patch_lang_core(lang, builder):
def _new_reduce(input, axis, combine_fn):
fn = combine_fn.fn.__name__
mapping = {
'maximum': np.max,
'_sum_combine': np.sum,
"maximum": np.max,
"_sum_combine": np.sum,
}
ret = mapping[fn](input.handle.data, axis=axis)
ret_type = tl.block_type(input.dtype, ret.shape)
@@ -397,15 +408,16 @@ def _patch_lang_core(lang, builder):
def _patch_lang_math(lang, builder):
math = lang.math
mapping = {
'abs': 'abs',
'acos': 'arccos',
'asin': 'arcsin',
'exp2': 'exp2',
'log2': 'log2',
'max': 'maximum',
"abs": "abs",
"acos": "arccos",
"asin": "arcsin",
"exp2": "exp2",
"log2": "log2",
"max": "maximum",
}
def make_numpy(name):
def impl(*args, **kwargs):
ret_type = args[0].type # TODO: incorrect
ret_dtype = args[0].dtype # TODO: incorrect
@@ -414,15 +426,18 @@ def _patch_lang_math(lang, builder):
ret = getattr(np, mapping[name])(*args, **kwargs)
ret = tl.core.tensor(TensorHandle(ret, ret_dtype), ret_type)
return ret
return impl
def make_fallback(name):
def fallback(*args, **kwargs):
raise NotImplementedError(f"""
{name} not supported in interpreter mode: no known numpy implementation.
If you think that {name} in fact does have a numpy implementation, please add it
to the mapping in python/triton/interpreter/new_interpreter.py:_patch_lang_math.
""")
return fallback
for name, member in inspect.getmembers(math):
@@ -438,7 +453,7 @@ def _implicit_cvt(arg):
ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
handle = TensorHandle(np.array([arg], dtype=np.int32), ty)
return tl.tensor(handle, ty)
if hasattr(arg, 'data_ptr'):
if hasattr(arg, "data_ptr"):
ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
return tl.tensor(handle, ty)
@@ -453,28 +468,29 @@ def _unwrap(tensor):
builder = Builder()
RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization', 'enable_fp_fusion']
RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_warp_specialization", "enable_fp_fusion"]
class GridExecutor:
def __init__(self, fn, arg_names, grid):
from .jit import _normalize_ty # TODO: modularize
self.fn = fn
self.arg_names = arg_names
self.grid = grid
__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
self.constexprs = [name for name in arg_names if __annotations__.get(name) == 'constexpr']
self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"]
def _patch_lang(self, builder):
lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
assert len(lang) == 1, "triton.language must be visible from within jit'd function"
_patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
_patch_lang_tensor(getattr(lang[0], "tensor"), builder)
_patch_lang_core(lang[0], builder)
_patch_lang_math(lang[0], builder)
def __call__(self, *args_dev, **kwargs):
args_hst = [_unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev]
args_hst = [_unwrap(arg).cpu() if hasattr(arg, "data_ptr") else arg for arg in args_dev]
# removes reserved keywords from kwargs
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}
# remaps core language functions to interpreted ones
@@ -486,7 +502,7 @@ class GridExecutor:
# iterate through grid
grid = self.grid(args) if callable(self.grid) else self.grid
assert len(grid) <= 3
grid = grid + (1,) * (3 - len(grid))
grid = grid + (1, ) * (3 - len(grid))
builder.set_grid_dim(*grid)
for x in range(grid[0]):
for y in range(grid[1]):
@@ -495,7 +511,7 @@ class GridExecutor:
self.fn(**args)
# copy arguments back to propagate side-effects
for arg_dev, arg_hst in zip(args_dev, args_hst):
if hasattr(arg_dev, 'data_ptr'):
if hasattr(arg_dev, "data_ptr"):
_unwrap(arg_dev).copy_(arg_hst.to(arg_dev.device))
@@ -504,17 +520,18 @@ class InterpretedFunction:
def _patch_lang(self, builder):
lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
assert len(lang) == 1, "triton.language must be visible from within jit'd function"
_patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
_patch_lang_tensor(getattr(lang[0], "tensor"), builder)
_patch_lang_core(lang[0], builder)
def __init__(self, fn) -> None:
self.fn = fn
def run(*args, **kwargs):
grid = kwargs['grid']
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ['grid']}
grid = kwargs["grid"]
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ["grid"]}
return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs)
self.run = run
signature = inspect.signature(fn)
self.arg_names = [v.name for v in signature.parameters.values()]

View File

@@ -5,48 +5,48 @@ import functools
import hashlib
import inspect
import os
import subprocess
import textwrap
from collections import defaultdict, namedtuple
from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast,
overload)
from functools import cached_property
from typing import Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast, overload
from .._C.libtriton.triton import TMAInfos
from ..common.backend import get_backend, path_to_ptxas
from ..language.core import dtype
from ..common.backend import get_backend, get_cuda_version_key
from .interpreter import InterpretedFunction
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TRITON_VERSION = "2.1.0"
def get_cuda_stream(idx=None):
if idx is None:
idx = get_current_device()
try:
from torch._C import _cuda_getCurrentRawStream
return _cuda_getCurrentRawStream(idx)
except ImportError:
import torch
return torch.cuda.current_stream(idx).cuda_stream
def get_current_device():
import torch
return torch.cuda.current_device()
def set_current_device(idx):
import torch
torch.cuda.set_device(idx)
def get_device_capability(idx):
import torch
return torch.cuda.get_device_capability(idx)
T = TypeVar('T')
T = TypeVar("T")
# -----------------------------------------------------------------------------
# Dependencies Finder
@@ -72,7 +72,8 @@ class DependenciesFinder(ast.NodeVisitor):
lhs = self.visit(node.value)
while isinstance(lhs, ast.Attribute):
lhs = self.visit(lhs.value)
if lhs is None or (getattr(lhs, "__name__", "") == "triton" or getattr(lhs, "__name__", "").endswith(".triton")):
if lhs is None or (getattr(lhs, "__name__", "") == "triton"
or getattr(lhs, "__name__", "").endswith(".triton")):
return None
return getattr(lhs, node.attr)
@@ -82,55 +83,26 @@ class DependenciesFinder(ast.NodeVisitor):
return
if inspect.isbuiltin(func):
return
if func.__module__ and (func.__module__.startswith('triton.') or '.triton.' in func.__module__):
if func.__module__ and (func.__module__.startswith("triton.") or ".triton." in func.__module__):
return
assert isinstance(func, JITFunction), f"Function \"{func.__name__}\" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this"
assert isinstance(
func, JITFunction
), f'Function "{func.__name__}" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this'
if func.hash is None:
tree = ast.parse(func.src)
finder = DependenciesFinder(func.__globals__, func.src)
finder.visit(tree)
func.hash = finder.ret
noinline = str(getattr(func, 'noinline', False))
noinline = str(getattr(func, "noinline", False))
self.ret = (self.ret + func.hash + noinline).encode("utf-8")
self.ret = hashlib.sha1(self.ret).hexdigest()
# -----------------------------------------------------------------------------
# JITFunction
# -----------------------------------------------------------------------------
@functools.lru_cache()
def version_key():
import pkgutil
contents = []
# frontend
with open(__file__, "rb") as f:
contents += [hashlib.sha1(f.read()).hexdigest()]
# compiler
compiler_path = os.path.join(TRITON_PATH, 'compiler')
for lib in pkgutil.iter_modules([compiler_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.sha1(f.read()).hexdigest()]
# backend
libtriton_hash = hashlib.sha1()
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
while True:
chunk = f.read(1024 ** 2)
if not chunk:
break
libtriton_hash.update(chunk)
contents.append(libtriton_hash.hexdigest())
# language
language_path = os.path.join(TRITON_PATH, 'language')
for lib in pkgutil.iter_modules([language_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.sha1(f.read()).hexdigest()]
# ptxas version
ptxas = path_to_ptxas()[0]
ptxas_version = hashlib.sha1(subprocess.check_output([ptxas, "--version"])).hexdigest()
return '-'.join(TRITON_VERSION) + '-' + ptxas_version + '-' + '-'.join(contents)
def _normalize_ty(ty) -> str:
if isinstance(ty, type):
return ty.__name__
@@ -139,6 +111,85 @@ def _normalize_ty(ty) -> str:
return repr(ty)
class KernelParam:
"""Represents a parameter to a @jit'ed function.
A parameter is just the name plus metadata; a parameter plus a value is a
KernelArg.
"""
def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool):
self.num = num
self._param = param
self.do_not_specialize = do_not_specialize
@cached_property
def name(self):
return self._param.name
@cached_property
def annotation(self):
if not self._param.annotation or self._param.annotation == inspect.Parameter.empty:
return ""
return _normalize_ty(self._param.annotation)
@cached_property
def is_constexpr(self):
return "constexpr" in self.annotation
@property
def default(self):
return self._param.default
@property
def has_default(self):
return self._param.default != inspect.Parameter.empty
class KernelArg:
"""Represents an argument to a @jit'ed function.
An argument is a parameter plus a value.
"""
def __init__(self, value, param):
self.value = value
self.param = param
@property
def name(self):
return self.param.name
def signature_key(self):
annotation = self.param.annotation
if "Tensor" in annotation:
return self.value.dtype
elif annotation == "bool":
return "i1"
elif annotation == "float":
return "fp32"
else:
return JITFunction._key_of(self.value)
def specialization_key(self):
assert not self.param.do_not_specialize
try:
return (self.value.data_ptr() % JITFunction.divisibility == 0, )
except AttributeError:
pass
if isinstance(self.value, int):
# bool is a subclass of int, so we don't check explicitly above.
return (
self.value % JITFunction.divisibility == 0,
self.value % JITFunction.divisibility_8 == 0,
self.value == 1,
)
return (False, )
class KernelInterface(Generic[T]):
run: T
@@ -152,7 +203,6 @@ class KernelInterface(Generic[T]):
class JITFunction(KernelInterface[T]):
# Hook for inspecting compiled functions and modules
cache_hook = None
divisibility = 16
@@ -169,44 +219,44 @@ class JITFunction(KernelInterface[T]):
elif isinstance(arg, bool):
return "i1"
elif isinstance(arg, int):
if -2**31 <= arg and arg <= 2**31 - 1:
if -(2**31) <= arg and arg <= 2**31 - 1:
return "i32"
elif 2**63 <= arg and arg <= 2**64 - 1:
return "u64"
else:
return "i64"
elif isinstance(arg, float):
return 'fp32'
return "fp32"
elif arg is None:
return None
else:
raise TypeError(f'Unsupported type {type(arg)} for {arg}')
raise TypeError(f"Unsupported type {type(arg)} for {arg}")
@staticmethod
def _device_of(arg):
if hasattr(arg, "device"):
if hasattr(arg.device, 'type'):
return arg.device.type
return ''
try:
return arg.device.type
except AttributeError:
return ""
@staticmethod
def _pinned_memory_of(arg):
if hasattr(arg, "is_pinned"):
if isinstance(arg.is_pinned, Callable):
return arg.is_pinned()
return False
try:
return arg.is_pinned()
except (AttributeError, TypeError):
return False
@staticmethod
def _spec_of(arg):
if hasattr(arg, "data_ptr"):
return (arg.data_ptr() % JITFunction.divisibility == 0)
return arg.data_ptr() % JITFunction.divisibility == 0
elif isinstance(arg, int):
return (arg % 16 == 0, arg == 1)
return (arg is None, )
# TODO(jlebar): Fold this into the KernelArg class.
def _get_config(self, *args):
def is_divisible_by_16(x):
if hasattr(x, "data_ptr"):
return x.data_ptr() % JITFunction.divisibility == 0
@@ -222,28 +272,38 @@ class JITFunction(KernelInterface[T]):
if x is None:
return True
return False
divisible_by_16 = {i for i, arg in enumerate(
args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
divisible_by_8 = {i for i, arg in enumerate(
args) if is_divisible_by_8(arg) and i not in self.do_not_specialize}
divisible_by_16 = {
param.num
for param, arg in zip(self.params, args)
if is_divisible_by_16(arg) and not param.do_not_specialize
}
divisible_by_8 = {
param.num
for param, arg in zip(self.params, args)
if is_divisible_by_8(arg) and not param.do_not_specialize
}
equal_to_1 = {
i for i, arg in enumerate(args) if isinstance(
arg, int) and not isinstance(
arg, bool) and arg == 1 and i not in self.do_not_specialize}
param.num
for param, arg in zip(self.params, args)
if isinstance(arg, int) and not isinstance(arg, bool) and arg == 1 and not param.do_not_specialize
}
# folded equal_to_1 and None
# TODO: method to collect all folded args
none_args = {i for i, arg in enumerate(args) if arg is None and i not in self.do_not_specialize}
none_args = {param.num for param, arg in zip(self.params, args) if arg is None and not param.do_not_specialize}
ids_of_folded_args = equal_to_1 | none_args
return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])(
tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), tuple(divisible_by_8))
return namedtuple("instance_descriptor",
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])( #
tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args),
tuple(divisible_by_8))
# return _triton.code_gen.instance_descriptor(divisible_by_16,
# equal_to_1)
@staticmethod
def _type_of(key):
# None are nullptr -- implicitly converted to *i8
# `None` is nullptr. Implicitly convert to *i8.
if key is None:
return '*i8'
return "*i8"
dtype_str = str(key).split(".")[-1]
tys = {
"bool": "i1",
@@ -281,187 +341,265 @@ class JITFunction(KernelInterface[T]):
constants = dict(zip(self.constexprs, constexpr_key))
return constants
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization,enable_fp_fusion, extern_libs, configs):
def _call_hook(
self,
key,
signature,
device,
constants,
num_warps,
num_ctas,
num_stages,
waves_per_eu,
matrix_instr_nonkdim,
enable_warp_specialization,
enable_fp_fusion,
extern_libs,
configs,
):
if JITFunction.cache_hook is None:
return False
name = self.fn.__name__
module = self.fn.__module__
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
arg_reprs = ', '.join([f'{param.name}: {ty}' for param, ty in zip(self.params, key[1])])
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs}), enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
key = str(key)
class LegacyCompiler:
def __init__(self, module, name):
self.module = module
self.name = name
pass
kwargs = dict(signature=signature, device=device, constants=constants,
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs,
configs=configs)
kwargs = dict(
signature=signature,
device=device,
constants=constants,
num_warps=num_warps,
num_ctas=num_ctas,
num_stages=num_stages,
waves_per_eu=waves_per_eu,
enable_warp_specialization=enable_warp_specialization,
enable_fp_fusion=enable_fp_fusion,
extern_libs=extern_libs,
configs=configs)
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={
"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
def _get_arg_specialization_key(self, arg_name, arg):
arg_annotation = self.__annotations__.get(arg_name, '')
if arg_annotation == '':
return (arg.data_ptr() % JITFunction.divisibility == 0) if hasattr(arg, "data_ptr") \
else (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1) if isinstance(arg, int) \
else (False,)
elif 'Tensor' in arg_annotation:
return (arg.data_ptr() % JITFunction.divisibility == 0)
elif 'int' in arg_annotation or 'bool' in arg_annotation:
return (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1)
else:
return (False,)
def _get_arg_sig_key(self, arg_name, arg) -> str:
arg_annotation = self.__annotations__.get(arg_name, '')
if 'Tensor' in arg_annotation:
return arg.dtype
elif arg_annotation == 'bool':
return "i1"
elif arg_annotation == 'float':
return 'fp32'
else:
return self._key_of(arg)
return JITFunction.cache_hook(
key=key,
repr=repr,
fn=LegacyCompiler(module, name),
compile={"key": key, **kwargs},
is_manual_warmup=False,
already_compiled=False,
)
def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: List[bool]) -> str:
device_types = [device_type for device_type in device_types if device_type != '']
device_types = [device_type for device_type in device_types if device_type != ""]
# Return cuda if one of the input tensors is cuda
if 'cuda' in device_types:
if "cuda" in device_types:
import torch
return 'hip' if torch.version.hip else 'cuda'
is_cpu = all(device_type == 'cpu' for device_type in device_types)
return "hip" if torch.version.hip else "cuda"
is_cpu = all(device_type == "cpu" for device_type in device_types)
is_pinned_memory = any(pinned_memory_flag for pinned_memory_flag in pinned_memory_flags)
# Return cuda if all the input tensors are cpu while the memory is pinned
if is_cpu and is_pinned_memory:
return 'cuda'
return "cuda"
return device_types[0] if len(device_types) > 0 else 'cuda'
return device_types[0] if len(device_types) > 0 else "cuda"
def _make_launcher(self):
regular_args = [arg for i, arg in enumerate(
self.arg_names) if i not in self.constexprs]
constexpr_args = [arg for i, arg in enumerate(
self.arg_names) if i in self.constexprs]
def run(self, *args, **kwargs):
from ..compiler import CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps
def regular_args_v(args_proxy):
return [args_proxy[arg_name] for arg_name in regular_args]
# Get a compiler-flags arg like `num_warps` and remove it from kwargs.
def get_special_arg(name: str, default=None):
if name not in kwargs:
return default
ret = kwargs[name]
del kwargs[name]
return ret
def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type):
from ..compiler import (CompiledKernel, compile,
get_arch_default_num_stages,
get_arch_default_num_warps)
sig_key = tuple([self._get_arg_sig_key(arg_name, args_proxy[arg_name]) for arg_name in regular_args])
constexpr_key = tuple([args_proxy[arg_name] for arg_name in constexpr_args])
specializations = []
for i, arg_name in enumerate(regular_args):
if i in self.do_not_specialize:
continue
specializations += [self._get_arg_specialization_key(arg_name, args_proxy[arg_name])]
grid = get_special_arg("grid")
num_warps = get_special_arg("num_warps")
num_ctas = get_special_arg("num_ctas", 1)
num_stages = get_special_arg("num_stages")
waves_per_eu = get_special_arg("waves_per_eu", 0)
matrix_instr_nonkdim = get_special_arg("matrix_instr_nonkdim", 0)
enable_warp_specialization = get_special_arg("enable_warp_specialization", False)
enable_fp_fusion = get_special_arg("enable_fp_fusion", True)
extern_libs = get_special_arg("extern_libs")
stream = get_special_arg("stream")
warmup = get_special_arg("warmup", False)
device = get_special_arg("device")
device_type = get_special_arg("device_type")
spec_key = tuple(specializations)
assert num_ctas > 0
assert grid is not None
if callable(grid):
grid = grid(args_proxy)
grid_size = len(grid)
grid_0 = grid[0]
grid_1 = grid[1] if grid_size > 1 else 1
grid_2 = grid[2] if grid_size > 2 else 1
if device_type is None:
device_types = [self._device_of(arg) for arg in regular_args_v(args_proxy)]
device_types = [_device_type for _device_type in device_types if _device_type != '']
device_type = self._conclude_device_type(device_types, [self._pinned_memory_of(arg) for arg in
regular_args_v(args_proxy)])
# Bind the remaining arguments to `fn`.
bound_args = self.signature.bind(*args, **kwargs)
bound_args.apply_defaults()
device_backend = None
if device_type not in ['cuda']:
device_backend = get_backend(device_type)
if device_backend is None:
raise ValueError('Cannot find backend for ' + device_type)
assert len(bound_args.arguments) == len(self.params)
args = [KernelArg(arg_value, param) for (_, arg_value), param in zip(bound_args.arguments.items(), self.params)]
if device is None:
if device_type in ['cuda']:
device = get_current_device()
set_current_device(device)
else:
device = device_backend.get_current_device()
device_backend.set_current_device(device)
if stream is None and not warmup:
if device_type in ['cuda']:
stream = get_cuda_stream(device)
else:
stream = device_backend.get_stream()
non_constexpr_arg_values = [arg.value for arg in args if not arg.param.is_constexpr]
if num_warps is None:
num_warps = get_arch_default_num_warps(device_type)
if num_stages is None:
num_stages = get_arch_default_num_stages(device_type)
sig_key = tuple(arg.signature_key() for arg in args if not arg.param.is_constexpr)
spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize)
constexpr_key = tuple(arg.value for arg in args if arg.param.is_constexpr)
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, enable_fp_fusion, self.debug)
if extern_libs is not None:
key = (key, tuple(extern_libs.items()))
assert num_ctas > 0
assert grid is not None
if callable(grid):
# Arguments are passed as a dict to `grid`, by contract.
# TODO(jlebar): In the new launch API, pass the compiler flags as a
# second parameter to `grid`.
grid = grid(dict(bound_args.arguments))
grid_size = len(grid)
grid_0 = grid[0]
grid_1 = grid[1] if grid_size > 1 else 1
grid_2 = grid[2] if grid_size > 2 else 1
if device_type is None:
device_types = [self._device_of(arg) for arg in non_constexpr_arg_values]
device_types = [_device_type for _device_type in device_types if _device_type != ""]
device_type = self._conclude_device_type(device_types,
[self._pinned_memory_of(arg) for arg in non_constexpr_arg_values])
bin = self.cache[device].get(key, None)
if bin is not None:
# build dict of constant values
args = regular_args_v(args_proxy)
# Create tensormaps and append to args
args = bin.assemble_tensormap_to_arg(args)
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
return bin
# kernel not cached -- compile
device_backend = None
if device_type not in ["cuda"]:
device_backend = get_backend(device_type)
if device_backend is None:
raise ValueError("Cannot find backend for " + device_type)
if device is None:
if device_type in ["cuda"]:
device = get_current_device()
set_current_device(device)
else:
# build dict of constant values
args = regular_args_v(args_proxy)
all_args = tuple([args_proxy[arg_name] for arg_name in self.arg_names])
configs = self._get_config(*all_args),
constants = self._make_constants(constexpr_key)
constants.update({i: None for i, arg in enumerate(all_args) if arg is None})
constants.update({i: 1 for i in configs[0].equal_to_1})
# build kernel signature -- doesn't include specialized arguments
signature = {i: self._type_of(self._key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs}
# build stub signature -- includes arguments that are specialized
for i, arg in constants.items():
if callable(arg):
raise TypeError(f"Callable constexpr at index {i} is not supported")
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, enable_fp_fusion, extern_libs, configs):
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=matrix_instr_nonkdim, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
# Create tensormaps and append to args
args = bin.assemble_tensormap_to_arg(args)
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
self.cache[device][key] = bin
return bin
device = device_backend.get_current_device()
device_backend.set_current_device(device)
if stream is None and not warmup:
if device_type in ["cuda"]:
stream = get_cuda_stream(device)
else:
stream = device_backend.get_stream()
if num_warps is None:
num_warps = get_arch_default_num_warps(device_type)
if num_stages is None:
num_stages = get_arch_default_num_stages(device_type)
if device_type in ["cuda"]:
version_key = get_cuda_version_key()
else:
version_key = device_backend.get_version_key()
key = (
version_key,
sig_key,
constexpr_key,
spec_key,
num_warps,
num_ctas,
num_stages,
waves_per_eu,
matrix_instr_nonkdim,
enable_warp_specialization,
enable_fp_fusion,
self.debug,
)
if extern_libs is not None:
key = (key, tuple(extern_libs.items()))
# Kernel is not cached; we have to compile.
if key not in self.cache[device]:
configs = (self._get_config(*[arg.value for arg in args]), )
constants = {
arg.param.num: arg.value
for arg in args
if arg.param.is_constexpr or arg.param.num in configs[0].equal_to_1 or arg.value is None
}
for i, arg in constants.items():
if callable(arg):
raise TypeError(f"Callable constexpr at index {i} is not supported")
# Build kernel signature -- doesn't include constexpr arguments.
signature = {
arg.param.num: self._type_of(self._key_of(arg.value))
for arg in args
if not arg.param.is_constexpr
}
if self._call_hook(
key,
signature,
device,
constants,
num_warps,
num_ctas,
num_stages,
waves_per_eu,
matrix_instr_nonkdim,
enable_warp_specialization,
enable_fp_fusion,
extern_libs,
configs,
):
return None
# create a wrapper to call launcher_body
args_map = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = triton.language.dtype(\'{dflt}\')' if dtype.is_dtype(f'{dflt}') else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
src = f"""
import triton
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, matrix_instr_nonkdim=0, enable_warp_specialization=False, enable_fp_fusion=True, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type)
"""
scope = {"launcher_body": launcher_body}
exec(src, scope)
return scope[self.fn.__name__]
self.cache[device][key] = compile(
self,
signature=signature,
device=device,
constants=constants,
num_warps=num_warps,
num_ctas=num_ctas,
num_stages=num_stages,
waves_per_eu=waves_per_eu,
matrix_instr_nonkdim=matrix_instr_nonkdim,
enable_warp_specialization=enable_warp_specialization,
enable_fp_fusion=enable_fp_fusion,
extern_libs=extern_libs,
configs=configs,
debug=self.debug,
device_type=device_type,
)
bin = self.cache[device][key]
if not warmup:
bin.c_wrapper(
grid_0,
grid_1,
grid_2,
bin.num_warps,
bin.num_ctas,
bin.clusterDims[0],
bin.clusterDims[1],
bin.clusterDims[2],
bin.shared,
stream,
bin.cu_function,
CompiledKernel.launch_enter_hook,
CompiledKernel.launch_exit_hook,
bin,
*bin.assemble_tensormap_to_arg(non_constexpr_arg_values),
)
return bin
def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None):
do_not_specialize = do_not_specialize if do_not_specialize else []
self.fn = fn
self.module = fn.__module__
self.version = version
# function signature information
signature = inspect.signature(fn)
self.arg_names = [v.name for v in signature.parameters.values()]
self.arg_defaults = [v.default for v in signature.parameters.values()]
self.has_defaults = any(v != inspect._empty for v in self.arg_defaults)
self.signature = inspect.signature(fn)
self.do_not_specialize = do_not_specialize
self.params = []
for i, param in enumerate(self.signature.parameters.values()):
dns = do_not_specialize and (i in do_not_specialize or param.name in do_not_specialize)
self.params.append(KernelParam(i, param, dns))
# function source code (without decorators)
self.src = textwrap.dedent(inspect.getsource(fn))
self.src = self.src[self.src.find("def"):]
@@ -470,22 +608,18 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
self.hash = None
# JITFunction can be instantiated as kernel
# when called with a grid using __getitem__
self.kernel_decorators = []
self.kernel = None
self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug
self.noinline = noinline
# annotations
self.__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
# index of constexprs
self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
# specialization hints
regular_args = [arg for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
self.do_not_specialize = {regular_args.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
# tma info
self.tensormaps_info = TMAInfos()
# launcher
self.run = self._make_launcher()
# TODO(jlebar): Remove uses of these fields outside this file, then
# remove the fields here.
self.arg_names = [p.name for p in self.params]
self.constexprs = [p.num for p in self.params if p.is_constexpr]
# re-use docs of wrapped function
self.__doc__ = fn.__doc__
self.__name__ = fn.__name__
@@ -498,7 +632,7 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
if self.hash is None:
dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
dependencies_finder.visit(self.parse())
self.hash = dependencies_finder.ret + version_key()
self.hash = dependencies_finder.ret
return self.hash
def warmup(self, *args, **kwargs):
@@ -518,14 +652,10 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
def __setattr__(self, name, value):
# - when kernel decorators change, cached kernel
# needs to be cleared
if name == 'kernel_decorators':
self.kernel = None
super(JITFunction, self).__setattr__(name, value)
# - when `.src` attribute is set, cache path needs
# to be reinitialized
if name == 'src':
if name == "src":
self.hash = None
def __repr__(self):
@@ -591,12 +721,14 @@ def jit(
debug=debug,
noinline=noinline,
)
if fn is not None:
return decorator(fn)
else:
return decorator
# -----------------------------------------------------------------------------
# Utilities for mocking tensors
# -----------------------------------------------------------------------------
@@ -607,10 +739,10 @@ class MockTensor:
Can be used in place of real tensors when calling:
kernel.warmup(MockTensor(torch.float32), ...)
"""
@staticmethod
def wrap_dtype(arg):
if arg.__class__.__name__ == "dtype" and\
arg.__module__ == "torch":
if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch":
return MockTensor(arg)
return arg
@@ -623,6 +755,7 @@ class MockTensor:
class TensorWrapper:
def __init__(self, base, dtype):
self.dtype = dtype
self.base = base
@@ -637,7 +770,7 @@ class TensorWrapper:
return self.base.stride(i)
def __str__(self) -> str:
return f'TensorWrapper[{self.dtype}]({self.base})'
return f"TensorWrapper[{self.dtype}]({self.base})"
def element_size(self):
return self.base.element_size()
@@ -655,4 +788,4 @@ def reinterpret(tensor, dtype):
# A new wrapper is needed around an unwrapped tensor.
return TensorWrapper(tensor, dtype)
else:
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
raise TypeError(f"Cannot reinterpret a {type(tensor)}.")

View File

@@ -78,10 +78,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None):
return torch.mean(torch.tensor(ret)).item()
def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
quantiles=None,
fast_flush=True,
return_mode="mean"):
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"):
assert return_mode in ["min", "max", "mean", "median"]
import torch
"""
@@ -261,11 +258,12 @@ class Benchmark:
class Mark:
def __init__(self, fn, benchmarks):
self.fn = fn
self.benchmarks = benchmarks
def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, **kwrags):
def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False, **kwrags):
import os
import matplotlib.pyplot as plt
@@ -321,24 +319,36 @@ class Mark:
if save_path:
plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
df = df[x_names + bench.line_names]
if diff_col and df.shape[1] == 2:
col0, col1 = df.columns.tolist()
df['Diff'] = df[col1] - df[col0]
if print_data:
print(bench.plot_name + ':')
print(df)
if save_path:
df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)
return df
def run(self, show_plots=False, print_data=False, save_path='', **kwargs):
def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs):
has_single_bench = isinstance(self.benchmarks, Benchmark)
benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
result_dfs = []
if save_path:
html = open(os.path.join(save_path, "results.html"), "w")
html.write("<html><body>\n")
for bench in benchmarks:
self._run(bench, save_path, show_plots, print_data, **kwargs)
result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
if save_path:
html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
if save_path:
html.write("</body></html>\n")
if return_df:
if has_single_bench:
return result_dfs[0]
else:
return result_dfs
return None
def perf_report(benchmarks):
@@ -393,12 +403,15 @@ def get_max_tensorcore_tflops(dtype, clock_rate, backend=None, device=None):
tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
return tflops
# create decorator that wraps test function into
# a cuda-memcheck system call
def cuda_memcheck(**target_kwargs):
def decorator(test_fn):
@functools.wraps(test_fn)
def wrapper(*args, **kwargs):
import psutil
@@ -416,7 +429,9 @@ def cuda_memcheck(**target_kwargs):
assert "ERROR SUMMARY: 0 errors" in str(out.stdout)
else:
test_fn(*args, **kwargs)
return wrapper
return decorator
@@ -424,22 +439,18 @@ def cuda_memcheck(**target_kwargs):
def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
try:
subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"])
subprocess.check_output(
[
"nvidia-smi",
"-i",
"0",
f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
]
)
subprocess.check_output(
[
"nvidia-smi",
"-i",
"0",
f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
]
)
subprocess.check_output([
"nvidia-smi",
"-i",
"0",
f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
])
subprocess.check_output([
"nvidia-smi",
"-i",
"0",
f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
])
cur_sm_clock = nvsmi(["clocks.current.sm"])[0]
cur_mem_clock = nvsmi(["clocks.current.memory"])[0]
assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz"

View File

@@ -8,8 +8,8 @@ from typing import Any, Tuple
from triton.common import _build
from triton.common.backend import BaseBackend, register_backend
from triton.compiler.make_launcher import get_cache_manager, version_key, make_so_cache_key
from triton.common.backend import BaseBackend, register_backend, compute_core_version_key
from triton.compiler.make_launcher import get_cache_manager, make_so_cache_key
from triton.compiler.utils import generate_cu_signature
from triton.runtime import jit
from triton.runtime.driver import HIPDriver
@@ -25,7 +25,7 @@ else:
def make_stub(name, signature, constants, ids, **kwargs):
# name of files that are cached
so_cache_key = make_so_cache_key(version_key(), signature, constants, ids, **kwargs)
so_cache_key = make_so_cache_key(compute_core_version_key(), signature, constants, ids, **kwargs)
so_cache_manager = get_cache_manager(so_cache_key)
so_name = f"{name}.so"
# retrieve stub from cache if it exists
@@ -414,11 +414,21 @@ def llir_to_amdgcn_and_hsaco(mod: Any, gfx_arch: str, gfx_triple: str, gfx_featu
class HIPBackend(BaseBackend):
_cached_rocm_version_key = None
def __init__(self, device_type: str) -> None:
super(HIPBackend, self).__init__(device_type)
self.driver = HIPDriver()
self.stub_so_path = ""
def get_version_key(self):
if self._cached_rocm_version_key is None:
key = compute_core_version_key()
### TODO: Append ROCM version here if needed
self._cached_rocm_version_key = key
return self._cached_rocm_version_key
def is_standalone(self):
return not HIP_BACKEND_MODE
@@ -500,7 +510,6 @@ class HIPBackend(BaseBackend):
return arch
def make_launcher_stub(self, name, signature, constants, ids):
# print("HIPBackend.make_launcher_stub")
self.stub_so_path = make_stub(name, signature, constants, ids)
return self.stub_so_path
@@ -517,4 +526,4 @@ class HIPBackend(BaseBackend):
return _triton.get_num_warps(module)
def get_matrix_core_version(self):
return gpu_matrix_core_version()
return gpu_matrix_core_version()

View File

@@ -141,8 +141,7 @@ class ExternLibrary(ABC):
f.write(file_str)
f.close()
if self._format:
subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file],
stdout=subprocess.PIPE).communicate()
subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], stdout=subprocess.PIPE).communicate()
subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate()
@@ -208,56 +207,36 @@ class Libdevice(ExternLibrary):
# Group functions together by renaming.
renaming = {
'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh',
'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': 'add_rn',
'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru',
'dadd_rz': 'add_rz', 'fadd_rz': 'add_rz', 'asinf': 'asin',
'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2',
'atanhf': 'atanh', 'brevll': 'brev', 'cbrtf': 'cbrt',
'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign',
'cosf': 'cos', 'coshf': 'cosh', 'cospif': 'cospi',
'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn',
'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', 'ddiv_ru': 'div_ru',
'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf',
'erfcf': 'erfc', 'erfcinvf': 'erfcinv', 'erfcxf': 'erfcx',
'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10',
'exp2f': 'exp2', 'expm1f': 'expm1', 'fabsf': 'abs',
'fabs': 'abs', 'fast_fdividef': 'fast_dividef',
'fdimf': 'fdim', 'ffsll': 'ffs', 'floorf': 'floor',
'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn',
'fmaf_ru': 'fma_ru', 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod',
'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb',
'isinff': 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan',
'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint',
'llroundf': 'llround', 'logf': 'log', 'log10f': 'log10',
'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb',
'umax': 'max', 'llmax': 'max', 'ullmax': 'max', 'fmaxf': 'max',
'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min',
'fminf': 'min', 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd',
'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', 'dmul_ru': 'mul_ru',
'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz',
'umul24': 'mul24', 'umulhi': 'mulhi', 'mul64hi': 'mulhi',
'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': 'nextafter',
'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf',
'normcdfinvf': 'normcdfinv', 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow',
'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', 'drcp_rd': 'rcp_rd',
'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru',
'drcp_ru': 'rcp_ru', 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz',
'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d',
'roundf': 'round', 'rsqrtf': 'rsqrt', 'frsqrt_rn': 'rsqrt_rn',
'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit',
'signbitd': 'signbit', 'sinf': 'sin', 'sinhf': 'sinh',
'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd',
'dsqrt_rd': 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn',
'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', 'fsqrt_rz': 'sqrt_rz',
'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd',
'fsub_rn': 'sub_rn', 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru',
'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc',
'y0f': 'y0', 'y1f': 'y1', 'ynf': 'yn'
'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn':
'add_rn', 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', 'dadd_rz': 'add_rz', 'fadd_rz':
'add_rz', 'asinf': 'asin', 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', 'atanhf': 'atanh',
'brevll': 'brev', 'cbrtf': 'cbrt', 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', 'cosf': 'cos',
'coshf': 'cosh', 'cospif': 'cospi', 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru',
'ddiv_ru': 'div_ru', 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', 'erfcf': 'erfc', 'erfcinvf':
'erfcinv', 'erfcxf': 'erfcx', 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', 'exp2f': 'exp2',
'expm1f': 'expm1', 'fabsf': 'abs', 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', 'fdimf': 'fdim', 'ffsll':
'ffs', 'floorf': 'floor', 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', 'fmaf_ru': 'fma_ru',
'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', 'isinff':
'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', 'llroundf': 'llround', 'logf': 'log', 'log10f':
'log10', 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', 'umax': 'max', 'llmax': 'max', 'ullmax':
'max', 'fmaxf': 'max', 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', 'fminf': 'min',
'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn',
'dmul_ru': 'mul_ru', 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', 'umul24': 'mul24',
'umulhi': 'mulhi', 'mul64hi': 'mulhi', 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf':
'nextafter', 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', 'normcdfinvf': 'normcdfinv',
'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd',
'drcp_rd': 'rcp_rd', 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', 'drcp_ru': 'rcp_ru',
'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', 'roundf': 'round', 'rsqrtf': 'rsqrt',
'frsqrt_rn': 'rsqrt_rn', 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', 'signbitd': 'signbit',
'sinf': 'sin', 'sinhf': 'sinh', 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', 'dsqrt_rd':
'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru',
'fsqrt_rz': 'sqrt_rz', 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', 'fsub_rn': 'sub_rn',
'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', 'y0f': 'y0', 'y1f': 'y1', 'ynf':
'yn'
}
for symbol in self._symbols.values():
@@ -347,8 +326,7 @@ class LLVMDisassembler:
self._ll_file = "/tmp/extern_lib.ll"
def disasm(self, lib_path: str) -> None:
subprocess.Popen([self._path, lib_path, "-o", self.ll_file],
stdout=subprocess.PIPE).communicate()
subprocess.Popen([self._path, lib_path, "-o", self.ll_file], stdout=subprocess.PIPE).communicate()
@property
def ll_file(self) -> str:

View File

@@ -40,10 +40,13 @@ if __name__ == "__main__":
# command-line arguments
parser = ArgumentParser(description=desc)
parser.add_argument("path", help="Path to Python source containing desired kernel in its scope. File will be executed.")
parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", required=True)
parser.add_argument("path",
help="Path to Python source containing desired kernel in its scope. File will be executed.")
parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile",
required=True)
parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel")
parser.add_argument("--num-stages", "-ns", type=int, default=3, help="Number of stages (meta-parameter of the kernel)")
parser.add_argument("--num-stages", "-ns", type=int, default=3,
help="Number of stages (meta-parameter of the kernel)")
parser.add_argument("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel")
parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename")
parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True)
@@ -104,7 +107,8 @@ if __name__ == "__main__":
config = triton.compiler.instance_descriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1)
for i in equal_to_1:
constexprs.update({i: 1})
ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config], num_warps=args.num_warps, num_stages=args.num_stages)
ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config],
num_warps=args.num_warps, num_stages=args.num_stages)
arg_names = []
arg_types = []
for i in signature.keys():

View File

@@ -27,6 +27,7 @@ class KernelLinkerMeta:
class HeaderParser:
def __init__(self) -> None:
import re
@@ -42,7 +43,6 @@ class HeaderParser:
self.kernels = defaultdict(list)
def extract_linker_meta(self, header: str):
for ln in header.splitlines():
if ln.startswith("//"):
m = self.linker_directives.match(ln)
@@ -76,7 +76,7 @@ class HeaderParser:
m = self.c_sig.findall(c_sig)
if len(m):
tys, args = [], []
for (ty, arg_name) in m:
for ty, arg_name in m:
tys.append(ty)
args.append(arg_name)
return tys, args
@@ -84,7 +84,7 @@ class HeaderParser:
raise LinkerError(f"{c_sig} is not a valid argument signature")
def _match_suffix(self, suffix: str, c_sig: str):
args = c_sig.split(',')
args = c_sig.split(",")
s2i = {"c": 1, "d": 16}
num_specs = 0
sizes = []
@@ -110,7 +110,7 @@ class HeaderParser:
if name in self.kernels:
last: KernelLinkerMeta = self.kernels[name][-1]
for (cur, new_) in zip(last.arg_ctypes, ker.arg_ctypes):
for cur, new_ in zip(last.arg_ctypes, ker.arg_ctypes):
if cur != new_:
raise LinkerError(
f"Mismatched signature for kernel {name}: \n\texisting sig is: {','.join(last.arg_ctypes)}\n\tcurrent is: {','.join(ker.arg_ctypes)}"
@@ -152,7 +152,7 @@ void unload_{meta.orig_kernel_name}();
# generate dispatcher function for kernels with different meta-parameter and constant values
def make_default_algo_kernel(meta: KernelLinkerMeta) -> str:
src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n"
src += f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n"
src += (f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n")
src += "}\n"
return src
@@ -164,12 +164,22 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -
src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n"
src += "\n"
src += f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{"
src += (f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{")
src += "\n"
for meta in sorted(metas, key=lambda m: -m.num_specs):
cond_fn = lambda val, hint: f"({val} % {hint} == 0)" if hint == 16 else f"({val} == {hint})" if hint == 1 else None
conds = " && ".join([cond_fn(val, hint) for val, hint in zip(meta.arg_names, meta.sizes) if hint is not None])
src += f" if ({conds})\n"
cond_fn = ( #
lambda val, hint: f"({val} % {hint} == 0)" #
if hint == 16 #
else f"({val} == {hint})" #
if hint == 1 #
else None)
conds = " && ".join([ #
cond_fn(val, hint) #
for val, hint in zip(meta.arg_names, meta.sizes) #
if hint is not None
])
src += (f" if ({conds})\n" if any(meta.sizes) else "if (1)\n"
) # Edge case where no specializations hence no dispatching required
arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1]
src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n"
src += "\n"
@@ -183,7 +193,7 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -
src += f"void {mode}_{name}() {{"
src += "\n"
for meta in sorted(metas, key=lambda m: -m.num_specs):
src += f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n"
src += (f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n")
src += "}\n"
return src
@@ -252,7 +262,12 @@ if __name__ == "__main__":
help="Paths to header files to link. Must include linker directive annotations (autogenerated by ttc)",
)
parser.add_argument("--out", "-o", type=Path, help="Out filename")
parser.add_argument("--prefix", type=str, default="", help="String to prefix kernel dispatcher names")
parser.add_argument(
"--prefix",
type=str,
default="",
help="String to prefix kernel dispatcher names",
)
args = parser.parse_args()
# metadata