mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge pull request #410 from ROCmSoftwarePlatform/ifu-231117
Ifu 231117
This commit is contained in:
@@ -45,12 +45,12 @@ __all__ = [
|
||||
"tools",
|
||||
]
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
# misc. utilities that don't fit well
|
||||
# into any specific module
|
||||
# -------------------------------------
|
||||
|
||||
|
||||
def cdiv(x: int, y: int):
    """Divide ``x`` by ``y`` and round the quotient up.

    Computed as ``(x + y - 1) // y``, which is the ceiling of ``x / y`` when
    both operands are positive integers.

    Args:
        x: Dividend.
        y: Divisor.

    Returns:
        The floor-divided, pre-biased quotient described above.
    """
    # Bias the numerator by (y - 1) so that floor division rounds up.
    biased_numerator = x + y - 1
    return biased_numerator // y
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
@@ -10,8 +10,12 @@ from typing import Dict
|
||||
|
||||
from ..runtime.driver import DriverBase
|
||||
|
||||
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
TRITON_VERSION = "2.1.0"
|
||||
|
||||
|
||||
class BaseBackend:
|
||||
|
||||
def __init__(self, device_type: str) -> None:
|
||||
self.device_type = device_type
|
||||
|
||||
@@ -104,7 +108,7 @@ def get_backend(device_type: str):
|
||||
def _path_to_binary(binary: str):
|
||||
base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
|
||||
paths = [
|
||||
os.environ.get("TRITON_PTXAS_PATH", ""),
|
||||
os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
|
||||
os.path.join(base_dir, "third_party", "cuda", "bin", binary)
|
||||
]
|
||||
|
||||
@@ -132,3 +136,48 @@ def path_to_cuobjdump():
|
||||
@functools.lru_cache()
def path_to_nvdisasm():
    """Return whatever ``_path_to_binary`` resolves for the ``nvdisasm`` tool.

    The result is memoized by ``functools.lru_cache``, so the lookup is
    performed at most once per process.
    """
    binary_name = "nvdisasm"
    return _path_to_binary(binary_name)
|
||||
|
||||
|
||||
@functools.lru_cache()
def compute_core_version_key():
    """Build a string that identifies the installed Triton core code.

    The key is the Triton version followed by SHA-1 hex digests of, in order:

    1. this source file (the "frontend"),
    2. every module found in the ``compiler`` directory under ``TRITON_PATH``,
    3. the ``_C/libtriton.so`` shared library (the "backend"),
    4. every module found in the ``language`` directory under ``TRITON_PATH``.

    All components are joined with ``'-'``. The result is memoized by
    ``functools.lru_cache``, so the files are read at most once per process.

    Returns:
        str: ``"<version>-<sha1>-<sha1>-..."``.

    Raises:
        OSError: if any of the hashed files cannot be opened.
    """
    import pkgutil

    def _sha1_of_file(path, chunk_size=1024**2):
        # Stream the file so large binaries (libtriton.so) are never held in
        # memory at once; the digest equals hashing the whole content.
        digest = hashlib.sha1()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                digest.update(chunk)
        return digest.hexdigest()

    def _module_hashes(package_dir):
        # One digest per module that pkgutil discovers directly in package_dir.
        return [
            _sha1_of_file(lib.module_finder.find_spec(lib.name).origin)
            for lib in pkgutil.iter_modules([package_dir])
        ]

    # frontend
    contents = [_sha1_of_file(__file__)]
    # compiler
    contents += _module_hashes(os.path.join(TRITON_PATH, 'compiler'))
    # backend
    contents.append(_sha1_of_file(os.path.join(TRITON_PATH, "_C/libtriton.so")))
    # language
    contents += _module_hashes(os.path.join(TRITON_PATH, 'language'))

    # Treat the version as ONE component. Calling '-'.join() on the version
    # string itself would iterate its characters ("2.1.0" -> "2-.-1-.-0") and
    # plain concatenation would glue it to the first digest without a separator.
    return '-'.join([TRITON_VERSION, *contents])
|
||||
|
||||
|
||||
_cached_cuda_version_key = None
|
||||
|
||||
|
||||
def get_cuda_version_key():
    """Return the version key for the CUDA path, computing it on first use.

    The key is ``compute_core_version_key()`` followed by ``'-'`` and the
    SHA-1 hex digest of the output of ``ptxas --version``. If locating ptxas
    raises ``RuntimeError``, the placeholder ``b"NO_PTXAS"`` is hashed
    instead. The result is stored in the module-level
    ``_cached_cuda_version_key`` and reused on later calls.
    """
    global _cached_cuda_version_key

    # Fast path: the key was already computed earlier in this process.
    if _cached_cuda_version_key is not None:
        return _cached_cuda_version_key

    core_key = compute_core_version_key()
    try:
        ptxas_binary = path_to_ptxas()[0]
        version_output = subprocess.check_output([ptxas_binary, "--version"])
    except RuntimeError:
        version_output = b"NO_PTXAS"

    ptxas_digest = hashlib.sha1(version_output).hexdigest()
    _cached_cuda_version_key = core_key + '-' + ptxas_digest
    return _cached_cuda_version_key
|
||||
|
||||
@@ -92,9 +92,15 @@ def _build(name, src, srcdir):
|
||||
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
|
||||
|
||||
if is_hip():
|
||||
ret = subprocess.check_call([cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{hip_lib_dir}", "-lamdhip64", "-o", so])
|
||||
ret = subprocess.check_call([
|
||||
cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC",
|
||||
f"-L{hip_lib_dir}", "-lamdhip64", "-o", so
|
||||
])
|
||||
else:
|
||||
cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
|
||||
cc_cmd = [
|
||||
cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda",
|
||||
"-o", so
|
||||
]
|
||||
cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
|
||||
ret = subprocess.check_call(cc_cmd)
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from .compiler import (CompiledKernel, compile, get_arch_default_num_stages,
|
||||
get_arch_default_num_warps, instance_descriptor)
|
||||
from .compiler import (CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps,
|
||||
instance_descriptor)
|
||||
from .errors import CompilationError
|
||||
|
||||
__all__ = ["compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", "get_arch_default_num_stages"]
|
||||
__all__ = [
|
||||
"compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps",
|
||||
"get_arch_default_num_stages"
|
||||
]
|
||||
|
||||
@@ -10,8 +10,7 @@ from .._C.libtriton.triton import ir
|
||||
from ..language import constexpr, tensor
|
||||
# ideally we wouldn't need any runtime component
|
||||
from ..runtime import JITFunction
|
||||
from .errors import (CompilationError, CompileTimeAssertionFailure,
|
||||
UnsupportedLanguageConstruct)
|
||||
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
|
||||
|
||||
|
||||
def mangle_ty(ty):
|
||||
@@ -68,7 +67,10 @@ def _check_fn_args(node, fn, args):
|
||||
if fn.noinline:
|
||||
for idx, arg in enumerate(args):
|
||||
if not _is_constexpr(arg) and not _is_triton_scalar(arg):
|
||||
raise UnsupportedLanguageConstruct(fn.src, node, f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}')
|
||||
raise UnsupportedLanguageConstruct(
|
||||
fn.src, node,
|
||||
f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}'
|
||||
)
|
||||
|
||||
|
||||
def _get_fn_file_line(fn):
|
||||
@@ -89,6 +91,7 @@ _condition_types = {bool, int, type(None)} # Python types accepted for conditio
|
||||
|
||||
|
||||
class enter_sub_region:
|
||||
|
||||
def __init__(self, generator):
|
||||
self.generator = generator
|
||||
|
||||
@@ -109,6 +112,7 @@ class enter_sub_region:
|
||||
|
||||
# Check if the given syntax node has an "early" return
|
||||
class ContainsReturnChecker(ast.NodeVisitor):
|
||||
|
||||
def __init__(self, gscope):
|
||||
self.gscope = gscope
|
||||
|
||||
@@ -199,9 +203,10 @@ class ContainsReturnChecker(ast.NodeVisitor):
|
||||
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, function_name, target,
|
||||
module=None, is_kernel=False, function_types: Optional[Dict] = None,
|
||||
debug=False, noinline=False, file_name: Optional[str] = None, begin_line=0):
|
||||
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, function_name, target, module=None,
|
||||
is_kernel=False, function_types: Optional[Dict] = None, debug=False, noinline=False,
|
||||
file_name: Optional[str] = None, begin_line=0):
|
||||
self.context = context
|
||||
self.builder = ir.builder(context)
|
||||
self.file_name = file_name
|
||||
@@ -237,8 +242,10 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
))
|
||||
|
||||
def _define_name_lookup(self):
|
||||
|
||||
def local_lookup(name: str, absent):
|
||||
value = self.lscope.get(name, absent) # this needs to be re-fetched from `self` every time, because it gets switched occasionally
|
||||
# this needs to be re-fetched from `self` every time, because it gets switched occasionally
|
||||
value = self.lscope.get(name, absent)
|
||||
if value is not absent and name not in self.local_defs:
|
||||
self.global_uses[name] = value
|
||||
return value
|
||||
@@ -255,8 +262,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
return name_lookup
|
||||
|
||||
def set_value(self, name: str,
|
||||
value: Union[tensor, constexpr]) -> None:
|
||||
def set_value(self, name: str, value: Union[tensor, constexpr]) -> None:
|
||||
''' This function:
|
||||
called by visit_Assign() & visit_FunctionDef() to store left value (lvalue)
|
||||
1. record local defined name (FIXME: should consider control flow)
|
||||
@@ -338,7 +344,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.visit(init_node)
|
||||
# initialize function
|
||||
visibility = "public" if self.is_kernel else "private"
|
||||
self.fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline)
|
||||
self.fn = self.builder.get_or_insert_function(self.module, self.function_name,
|
||||
self.prototype.to_ir(self.builder), visibility, self.noinline)
|
||||
self.module.push_back(self.fn)
|
||||
entry = self.fn.add_entry_block()
|
||||
arg_values = []
|
||||
@@ -469,12 +476,23 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
rhs = self.visit(node.right)
|
||||
method_name = self._method_name_for_bin_op.get(type(node.op))
|
||||
if method_name is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
return self._apply_binary_method(method_name, lhs, rhs)
|
||||
|
||||
_method_name_for_bin_op: Dict[Type[ast.operator], str] = {
|
||||
ast.Add: '__add__', ast.Sub: '__sub__', ast.Mult: '__mul__', ast.Div: '__truediv__',
|
||||
ast.FloorDiv: '__floordiv__', ast.Mod: '__mod__', ast.Pow: '__pow__',
|
||||
ast.LShift: '__lshift__', ast.RShift: '__rshift__', ast.BitAnd: '__and__', ast.BitOr: '__or__', ast.BitXor: '__xor__',
|
||||
ast.Add: '__add__',
|
||||
ast.Sub: '__sub__',
|
||||
ast.Mult: '__mul__',
|
||||
ast.Div: '__truediv__',
|
||||
ast.FloorDiv: '__floordiv__',
|
||||
ast.Mod: '__mod__',
|
||||
ast.Pow: '__pow__',
|
||||
ast.LShift: '__lshift__',
|
||||
ast.RShift: '__rshift__',
|
||||
ast.BitAnd: '__and__',
|
||||
ast.BitOr: '__or__',
|
||||
ast.BitXor: '__xor__',
|
||||
}
|
||||
|
||||
def visit_then_else_blocks(self, node, liveins, then_block, else_block):
|
||||
@@ -508,7 +526,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if name in then_defs or name in else_defs:
|
||||
names.append(name)
|
||||
ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type)
|
||||
ir_ret_types.append(then_defs[name].handle.get_type() if name in then_defs else else_defs[name].handle.get_type())
|
||||
ir_ret_types.append(then_defs[name].handle.get_type() if name in
|
||||
then_defs else else_defs[name].handle.get_type())
|
||||
# variable defined in then but not in else
|
||||
if name in then_defs and name not in else_defs:
|
||||
else_defs[name] = liveins[name]
|
||||
@@ -602,8 +621,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
contains_return = ContainsReturnChecker(self.gscope).visit(node)
|
||||
if self.scf_stack and contains_return:
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node,
|
||||
"Cannot have `return` statements inside `while` or `for` statements in triton "
|
||||
None, node, "Cannot have `return` statements inside `while` or `for` statements in triton "
|
||||
"(note that this also applies to `return` statements that are inside functions "
|
||||
"transitively called from within `while`/`for` statements)")
|
||||
elif self.scf_stack or not contains_return:
|
||||
@@ -612,10 +630,13 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.visit_if_top_level(cond, node)
|
||||
else:
|
||||
cond = _unwrap_if_constexpr(cond)
|
||||
if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks
|
||||
# not isinstance - we insist the real thing, no subclasses and no ducks
|
||||
if type(cond) not in _condition_types:
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
|
||||
', '.join(_.__name__ for _ in _condition_types), type(cond).__name__))
|
||||
None, node,
|
||||
"`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
|
||||
', '.join(_.__name__ for _ in _condition_types),
|
||||
type(cond).__name__))
|
||||
if cond:
|
||||
self.visit_compound_statement(node.body)
|
||||
else:
|
||||
@@ -624,15 +645,52 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def visit_IfExp(self, node):
|
||||
cond = self.visit(node.test)
|
||||
if _is_triton_tensor(cond):
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node,
|
||||
"Triton does not support `if` expressions (ternary operators) with dynamic conditions, use `if` statements instead")
|
||||
cond = cond.to(language.int1, _builder=self.builder)
|
||||
# TODO: Deal w/ more complicated return types (e.g tuple)
|
||||
with enter_sub_region(self):
|
||||
ip, last_loc = self._get_insertion_point_and_loc()
|
||||
|
||||
then_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(then_block)
|
||||
then_val = language.core._to_tensor(self.visit(node.body), self.builder)
|
||||
then_block = self.builder.get_insertion_block()
|
||||
|
||||
else_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(else_block)
|
||||
# do not need to reset lscope since
|
||||
# ternary expressions cannot define new variables
|
||||
else_val = language.core._to_tensor(self.visit(node.orelse), self.builder)
|
||||
else_block = self.builder.get_insertion_block()
|
||||
|
||||
self._set_insertion_point_and_loc(ip, last_loc)
|
||||
|
||||
assert then_val.type == else_val.type, \
|
||||
f'ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}'
|
||||
ret_type = then_val.type
|
||||
|
||||
ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else []
|
||||
if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True)
|
||||
then_block.merge_block_before(if_op.get_then_block())
|
||||
if ret_type_ir:
|
||||
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
||||
self.builder.create_yield_op([then_val.handle])
|
||||
|
||||
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
||||
else_block.merge_block_before(if_op.get_else_block())
|
||||
if ret_type_ir:
|
||||
self.builder.set_insertion_point_to_end(if_op.get_else_block())
|
||||
self.builder.create_yield_op([else_val.handle])
|
||||
return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None
|
||||
else:
|
||||
cond = _unwrap_if_constexpr(cond)
|
||||
if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks
|
||||
|
||||
# not isinstance - we insist the real thing, no subclasses and no ducks
|
||||
if type(cond) not in _condition_types:
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
|
||||
', '.join(_.__name__ for _ in _condition_types), type(cond).__name__))
|
||||
None, node,
|
||||
"`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
|
||||
', '.join(_.__name__ for _ in _condition_types),
|
||||
type(cond).__name__))
|
||||
if cond:
|
||||
return self.visit(node.body)
|
||||
else:
|
||||
@@ -654,8 +712,10 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
return constexpr(lhs_value is not rhs_value)
|
||||
method_name = self._method_name_for_comp_op.get(type(node.ops[0]))
|
||||
if method_name is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
|
||||
return self._apply_binary_method(method_name, lhs, rhs)
|
||||
|
||||
_method_name_for_comp_op: Dict[Type[ast.cmpop], str] = {
|
||||
ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__'
|
||||
}
|
||||
@@ -664,11 +724,15 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
op = self.visit(node.operand)
|
||||
fn = self._method_name_for_unary_op.get(type(node.op))
|
||||
if fn is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST unary operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "AST unary operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
if _is_triton_tensor(op):
|
||||
return getattr(op, fn)(_builder=self.builder)
|
||||
return getattr(op, fn)()
|
||||
_method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'}
|
||||
|
||||
_method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {
|
||||
ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'
|
||||
}
|
||||
|
||||
def visit_While(self, node):
|
||||
with enter_sub_region(self) as sr:
|
||||
@@ -763,9 +827,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
iter_args = [self.visit(arg) for arg in node.iter.args]
|
||||
if IteratorClass == language.static_range:
|
||||
iterator = IteratorClass(*iter_args)
|
||||
static_range = range(iterator.start.value,
|
||||
iterator.end.value,
|
||||
iterator.step.value)
|
||||
static_range = range(iterator.start.value, iterator.end.value, iterator.step.value)
|
||||
for i in static_range:
|
||||
self.lscope[node.target.id] = constexpr(i)
|
||||
self.visit_compound_statement(node.body)
|
||||
@@ -902,8 +964,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def call_JitFunction(self, fn: JITFunction, args, kwargs):
|
||||
args = inspect.getcallargs(fn.fn, *args, **kwargs)
|
||||
args = [args[name] for name in fn.arg_names]
|
||||
args = [arg if _is_triton_tensor(arg)
|
||||
else constexpr(arg) for arg in args]
|
||||
args = [arg if _is_triton_tensor(arg) else constexpr(arg) for arg in args]
|
||||
# generate function def
|
||||
attributes = dict()
|
||||
constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
|
||||
@@ -921,8 +982,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
debug = self.debug if fn.debug is None else fn.debug
|
||||
file_name, begin_line = _get_fn_file_line(fn)
|
||||
generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module,
|
||||
function_name=fn_name, function_types=self.function_ret_types, debug=debug, noinline=fn.noinline,
|
||||
file_name=file_name, begin_line=begin_line, target=self.builder.target)
|
||||
function_name=fn_name, function_types=self.function_ret_types, debug=debug,
|
||||
noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
|
||||
target=self.builder.target)
|
||||
generator.visit(fn.parse())
|
||||
callee_ret_type = generator.last_ret_type
|
||||
self.function_ret_types[fn_name] = callee_ret_type
|
||||
@@ -950,7 +1012,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
kws = dict(self.visit(keyword) for keyword in node.keywords)
|
||||
args = [self.visit(arg) for arg in node.args]
|
||||
if fn is language.core.device_assert: # TODO: this should not be so hardcoded
|
||||
if fn is language.core.device_assert: # TODO: this should not be so hardcoded
|
||||
if not self.debug:
|
||||
return
|
||||
if isinstance(fn, JITFunction):
|
||||
@@ -971,16 +1033,21 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def visit_BoolOp(self, node: ast.BoolOp):
|
||||
if len(node.values) != 2:
|
||||
raise UnsupportedLanguageConstruct(None, node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node,
|
||||
"chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
|
||||
lhs = self.visit(node.values[0])
|
||||
rhs = self.visit(node.values[1])
|
||||
method_name = self._method_name_for_bool_op.get(type(node.op))
|
||||
if method_name is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
return self._apply_binary_method(method_name, lhs, rhs)
|
||||
|
||||
_method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
|
||||
def visit_NameConstant(self, node):
|
||||
return constexpr(node.value)
|
||||
|
||||
@@ -1013,7 +1080,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
evaluated = self.visit(value.value)
|
||||
if not _is_constexpr(evaluated):
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + str(type(evaluated)))
|
||||
None, node,
|
||||
"Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type "
|
||||
+ str(type(evaluated)))
|
||||
values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value)
|
||||
else:
|
||||
raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value)))
|
||||
@@ -1055,7 +1124,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
passed = _unwrap_if_constexpr(self.visit(node.args[0]))
|
||||
if not isinstance(passed, bool):
|
||||
raise NotImplementedError("Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values")
|
||||
raise NotImplementedError(
|
||||
"Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values"
|
||||
)
|
||||
if not passed:
|
||||
if arg_count == 1:
|
||||
message = ""
|
||||
@@ -1144,10 +1215,9 @@ def ast_to_ttir(fn, signature, specialization, constants, debug, target):
|
||||
file_name, begin_line = _get_fn_file_line(fn)
|
||||
|
||||
prototype = language.function_type([], arg_types)
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants,
|
||||
function_name=function_name, attributes=new_attrs,
|
||||
is_kernel=True, debug=debug, file_name=file_name, begin_line=begin_line,
|
||||
target=target)
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name,
|
||||
attributes=new_attrs, is_kernel=True, debug=debug, file_name=file_name,
|
||||
begin_line=begin_line, target=target)
|
||||
try:
|
||||
generator.visit(fn.parse())
|
||||
except CompilationError as e:
|
||||
|
||||
@@ -11,25 +11,21 @@ from typing import Any
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs,
|
||||
compile_ptx_to_cubin, get_env_vars, get_num_warps,
|
||||
get_shared_memory_size, ir, runtime,
|
||||
translate_llvmir_to_ptx,
|
||||
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs, compile_ptx_to_cubin, get_env_vars,
|
||||
get_num_warps, get_shared_memory_size, ir, runtime, translate_llvmir_to_ptx,
|
||||
translate_triton_gpu_to_llvmir)
|
||||
from ..common.backend import get_backend, path_to_ptxas
|
||||
from ..common.backend import get_backend, get_cuda_version_key, path_to_ptxas
|
||||
from ..common.build import is_hip
|
||||
# from ..runtime import driver, jit, JITFunction
|
||||
# TODO: runtime.errors
|
||||
from ..runtime.autotuner import OutOfResources
|
||||
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
|
||||
from ..runtime.driver import driver
|
||||
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device,
|
||||
get_device_capability, version_key)
|
||||
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device, get_device_capability)
|
||||
from ..tools.disasm import get_sass
|
||||
from .code_generator import ast_to_ttir
|
||||
from .make_launcher import make_stub
|
||||
from .utils import (InfoFromBackendForTensorMap, TensorMapManager,
|
||||
get_ids_of_tensormaps, parse_tma_info)
|
||||
from .utils import (InfoFromBackendForTensorMap, TensorMapManager, get_ids_of_tensormaps, parse_tma_info)
|
||||
|
||||
CUDA_DEFAULT_WARP_SIZE = 32
|
||||
|
||||
@@ -45,6 +41,7 @@ def _is_cuda(target):
|
||||
|
||||
|
||||
class LazyDict(dict):
|
||||
|
||||
def __getitem__(self, key):
|
||||
val = dict.__getitem__(self, key)
|
||||
if callable(val):
|
||||
@@ -102,8 +99,8 @@ def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, target):
|
||||
return mod
|
||||
|
||||
|
||||
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target,
|
||||
cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_inst_type):
|
||||
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization,
|
||||
enable_persistent, optimize_epilogue, matrix_inst_type):
|
||||
is_cuda = _is_cuda(target)
|
||||
if is_cuda:
|
||||
capability = target.capability
|
||||
@@ -173,6 +170,8 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target,
|
||||
if is_cuda and capability // 10 >= 9:
|
||||
pm.add_tritongpu_fence_insertion_pass()
|
||||
pm.add_tritongpu_ws_fixup_missing_attrs_pass()
|
||||
pm.add_tritongpu_optimize_thread_locality_pass()
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
@@ -196,6 +195,7 @@ def ttgir_to_llir(mod, extern_libs, target, tma_infos, waves_per_eu=0):
|
||||
|
||||
# PTX translation
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def ptx_get_version(cuda_version) -> int:
|
||||
'''
|
||||
@@ -260,7 +260,11 @@ def convert_type_repr(x):
|
||||
return x
|
||||
|
||||
|
||||
def make_hash(fn, target, env_vars, **kwargs):
|
||||
def make_hash(fn, target, env_vars, device_backend, **kwargs):
|
||||
if device_backend is None:
|
||||
version_key = get_cuda_version_key()
|
||||
else:
|
||||
version_key = device_backend.get_version_key()
|
||||
if isinstance(fn, JITFunction):
|
||||
configs = kwargs["configs"]
|
||||
signature = kwargs["signature"]
|
||||
@@ -274,16 +278,17 @@ def make_hash(fn, target, env_vars, **kwargs):
|
||||
enable_persistent = kwargs.get("enable_persistent", False)
|
||||
debug = kwargs.get("debug", False)
|
||||
# Get unique key for the compiled code
|
||||
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
|
||||
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1),
|
||||
sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
|
||||
configs_key = [get_conf_key(conf) for conf in configs]
|
||||
env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())]
|
||||
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
|
||||
key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
|
||||
return hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||
assert isinstance(fn, str)
|
||||
ignore_version = kwargs.get('ignore_version', False)
|
||||
if (ignore_version):
|
||||
return hashlib.md5((Path(fn).read_text()).encode("utf-8")).hexdigest()
|
||||
return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest()
|
||||
return hashlib.md5((Path(fn).read_text() + version_key).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
|
||||
@@ -320,12 +325,14 @@ else:
|
||||
|
||||
|
||||
def _get_jsonable_constants(constants):
|
||||
|
||||
def _is_jsonable(x):
|
||||
try:
|
||||
json.dumps(x)
|
||||
return True
|
||||
except (TypeError, OverflowError):
|
||||
return False
|
||||
|
||||
serialized_constants = {}
|
||||
for constant in constants:
|
||||
if _is_jsonable(constants[constant]):
|
||||
@@ -340,7 +347,9 @@ def parse_mlir_module(path, context):
|
||||
return module
|
||||
|
||||
|
||||
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], defaults=[set(), set(), set(), set()])
|
||||
instance_descriptor = namedtuple("instance_descriptor",
|
||||
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
|
||||
defaults=[set(), set(), set(), set()])
|
||||
|
||||
|
||||
def is_hip():
|
||||
@@ -382,10 +391,9 @@ def get_arch_default_num_stages(device_type, capability=None):
|
||||
|
||||
|
||||
def add_cuda_stages(target, extern_libs, stages):
    """Register the ``"ptx"`` and ``"cubin"`` entries in ``stages``.

    Each entry is a ``(load, lower)`` pair of callables:

    * ``"ptx"``: ``load`` reads a file as text; ``lower`` calls
      ``llir_to_ptx(src, target)``.
    * ``"cubin"``: ``load`` reads a file as bytes; ``lower`` calls
      ``ptx_to_cubin(src, target)``.

    Args:
        target: Passed through to ``llir_to_ptx`` / ``ptx_to_cubin``.
        extern_libs: Accepted but not used by this function.
        stages: Mapping that is updated in place.
    """

    def load_ptx(path):
        return Path(path).read_text()

    def lower_to_ptx(src):
        return llir_to_ptx(src, target)

    def load_cubin(path):
        return Path(path).read_bytes()

    def lower_to_cubin(src):
        return ptx_to_cubin(src, target)

    stages["ptx"] = (load_ptx, lower_to_ptx)
    stages["cubin"] = (load_cubin, lower_to_cubin)
|
||||
|
||||
|
||||
def compile(fn, **kwargs):
|
||||
@@ -431,7 +439,8 @@ def compile(fn, **kwargs):
|
||||
# build architecture descriptor
|
||||
if device_type == "cuda":
|
||||
_device_backend = get_backend(device_type)
|
||||
target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps, enable_fp_fusion=enable_fp_fusion)
|
||||
target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps,
|
||||
enable_fp_fusion=enable_fp_fusion)
|
||||
else:
|
||||
_device_backend = get_backend(device_type)
|
||||
assert _device_backend
|
||||
@@ -440,11 +449,12 @@ def compile(fn, **kwargs):
|
||||
# build compilation stages
|
||||
stages = dict()
|
||||
stages["ast"] = (lambda path: fn, None)
|
||||
stages["ttir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
|
||||
stages["ttir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttir(
|
||||
ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
|
||||
if is_cuda:
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttgir(
|
||||
ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info,
|
||||
enable_warp_specialization, enable_persistent, optimize_epilogue))
|
||||
stages["llir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, target, tma_infos))
|
||||
add_cuda_stages(target, extern_libs, stages)
|
||||
@@ -504,18 +514,21 @@ def compile(fn, **kwargs):
|
||||
if ir_name == 'ttgir':
|
||||
num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
|
||||
assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
|
||||
assert "num_warps" not in kwargs or int(num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile"
|
||||
assert "num_warps" not in kwargs or int(
|
||||
num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile"
|
||||
num_warps = int(num_warps_matches[0])
|
||||
param_tys = [convert_type_repr(ty) for ty in types]
|
||||
signature = {k: v for k, v in enumerate(param_tys)}
|
||||
first_stage = list(stages.keys()).index(ir_name)
|
||||
|
||||
# create cache manager
|
||||
fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), **kwargs))
|
||||
fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs))
|
||||
# managers used to dump and override IR for debugging
|
||||
enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
|
||||
fn_override_manager = get_override_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True))
|
||||
fn_dump_manager = get_dump_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True))
|
||||
fn_override_manager = get_override_manager(
|
||||
make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
|
||||
fn_dump_manager = get_dump_manager(
|
||||
make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
|
||||
|
||||
# determine name and extension type of provided function
|
||||
if isinstance(fn, JITFunction):
|
||||
@@ -528,9 +541,7 @@ def compile(fn, **kwargs):
|
||||
metadata_filename = f"{name}.json"
|
||||
|
||||
# The group is addressed by the metadata
|
||||
metadata_group = fn_cache_manager.get_group(
|
||||
metadata_filename
|
||||
) or {}
|
||||
metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
|
||||
|
||||
metadata_path = metadata_group.get(metadata_filename)
|
||||
|
||||
@@ -538,20 +549,21 @@ def compile(fn, **kwargs):
|
||||
with open(metadata_path) as f:
|
||||
metadata = json.load(f)
|
||||
if 'tensormaps_info' in metadata:
|
||||
metadata['tensormaps_info'] = [
|
||||
InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']]
|
||||
metadata['tensormaps_info'] = [InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']]
|
||||
else:
|
||||
metadata = {"num_warps": num_warps,
|
||||
"warp_size": warp_size,
|
||||
"num_ctas": num_ctas,
|
||||
"num_stages": num_stages,
|
||||
"waves_per_eu": waves_per_eu,
|
||||
"matrix_instr_nonkdim": matrix_instr_nonkdim,
|
||||
"enable_warp_specialization": enable_warp_specialization,
|
||||
"enable_persistent": enable_persistent,
|
||||
"constants": _get_jsonable_constants(constants),
|
||||
"debug": debug,
|
||||
"target": target, }
|
||||
metadata = {
|
||||
"num_warps": num_warps,
|
||||
"warp_size": warp_size,
|
||||
"num_ctas": num_ctas,
|
||||
"num_stages": num_stages,
|
||||
"waves_per_eu": waves_per_eu,
|
||||
"matrix_instr_nonkdim": matrix_instr_nonkdim,
|
||||
"enable_warp_specialization": enable_warp_specialization,
|
||||
"enable_persistent": enable_persistent,
|
||||
"constants": _get_jsonable_constants(constants),
|
||||
"debug": debug,
|
||||
"target": target,
|
||||
}
|
||||
metadata.update(get_env_vars())
|
||||
if ext == "ptx":
|
||||
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
|
||||
@@ -623,10 +635,7 @@ def compile(fn, **kwargs):
|
||||
|
||||
ids_of_folded_args = tuple([int(k) for k in configs[0].ids_of_folded_args]) if isinstance(fn, JITFunction) else ()
|
||||
if "clusterDims" not in metadata:
|
||||
metadata["clusterDims"] = [
|
||||
cluster_info.clusterDimX,
|
||||
cluster_info.clusterDimY,
|
||||
cluster_info.clusterDimZ]
|
||||
metadata["clusterDims"] = [cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ]
|
||||
|
||||
if len(tma_infos) > 0:
|
||||
metadata["tensormaps_info"] = parse_tma_info(tma_infos, ids_of_folded_args)
|
||||
@@ -640,7 +649,10 @@ def compile(fn, **kwargs):
|
||||
fn.tensormaps_info = metadata["tensormaps_info"]
|
||||
|
||||
ids_of_const_exprs = tuple(fn.constexprs) if isinstance(fn, JITFunction) else ()
|
||||
ids = {"ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs": ids_of_const_exprs}
|
||||
ids = {
|
||||
"ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs":
|
||||
ids_of_const_exprs
|
||||
}
|
||||
# cache manager
|
||||
if is_cuda:
|
||||
so_path = make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization)
|
||||
@@ -648,7 +660,8 @@ def compile(fn, **kwargs):
|
||||
so_path = _device_backend.make_launcher_stub(name, signature, constants, ids)
|
||||
# write-back metadata, if it didn't come from the cache
|
||||
if metadata_path is None:
|
||||
metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, binary=False)
|
||||
metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
|
||||
binary=False)
|
||||
fn_cache_manager.put_group(metadata_filename, metadata_group)
|
||||
|
||||
# return handle to compiled kernel
|
||||
@@ -698,10 +711,7 @@ class CompiledKernel:
|
||||
|
||||
if self.device_type in ["cuda"]:
|
||||
device = get_current_device()
|
||||
bin_path = {
|
||||
driver.HIP: "hsaco_path",
|
||||
driver.CUDA: "cubin"
|
||||
}[driver.backend]
|
||||
bin_path = {driver.HIP: "hsaco_path", driver.CUDA: "cubin"}[driver.backend]
|
||||
max_shared = driver.utils.get_device_properties(device)["max_shared_mem"]
|
||||
fn_load_binary = driver.utils.load_binary
|
||||
else:
|
||||
@@ -749,4 +759,5 @@ class CompiledKernel:
|
||||
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.clusterDims[0],
|
||||
self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function,
|
||||
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand)
|
||||
|
||||
return runner
|
||||
|
||||
@@ -3,9 +3,9 @@ import os
|
||||
import tempfile
|
||||
|
||||
from ..common import _build
|
||||
from ..common.backend import get_cuda_version_key
|
||||
from ..common.build import is_hip
|
||||
from ..runtime.cache import get_cache_manager
|
||||
from ..runtime.jit import version_key
|
||||
from .utils import generate_cu_signature
|
||||
|
||||
# ----- stub --------
|
||||
@@ -23,7 +23,7 @@ def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
|
||||
|
||||
def make_stub(name, signature, constants, ids, **kwargs):
|
||||
# name of files that are cached
|
||||
so_cache_key = make_so_cache_key(version_key(), signature, constants, ids, **kwargs)
|
||||
so_cache_key = make_so_cache_key(get_cuda_version_key(), signature, constants, ids, **kwargs)
|
||||
so_cache_manager = get_cache_manager(so_cache_key)
|
||||
so_name = f"{name}.so"
|
||||
# retrieve stub from cache if it exists
|
||||
@@ -40,6 +40,7 @@ def make_stub(name, signature, constants, ids, **kwargs):
|
||||
else:
|
||||
return cache_path
|
||||
|
||||
|
||||
# ----- source code generation --------
|
||||
|
||||
|
||||
@@ -100,7 +101,10 @@ def generate_launcher(constants, signature, ids):
|
||||
|
||||
# generate glue code
|
||||
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
|
||||
params = [i for i in signature.keys() if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)]
|
||||
params = [
|
||||
i for i in signature.keys()
|
||||
if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)
|
||||
]
|
||||
src = f"""
|
||||
#include \"cuda.h\"
|
||||
#include <stdbool.h>
|
||||
|
||||
@@ -158,19 +158,21 @@ class InfoFromBackendForTensorMap:
|
||||
|
||||
# dtype:cuda.CUtensorMapDataType | int
|
||||
def bytes_from_type(self, dtype):
|
||||
return {driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT8"]: 1,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ"]: 4}[dtype]
|
||||
return {
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT8"]: 1,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ"]: 4
|
||||
}[dtype]
|
||||
|
||||
def getTensorMapDataType(self):
|
||||
return self.tensorDataType
|
||||
@@ -259,22 +261,29 @@ class InfoFromBackendForTensorMap:
|
||||
self.getInterleave(),
|
||||
self.getSwizzle(),
|
||||
self.getL2Promotion(),
|
||||
self.getOobFill()
|
||||
self.getOobFill(),
|
||||
)
|
||||
|
||||
# make hashable to use as partial key in cache
|
||||
def __hash__(self):
|
||||
return hash((self.ids_of_folded_args, self.globalAddressArgIdx, tuple(self.globalDimsArgIdx), tuple(self.globalStridesArgIdx), self.tensorDataType,
|
||||
self.tensorRank, tuple(self.boxDims), tuple(self.elementStrides), self.interleave, self.swizzle, self.l2Promotion, self.oobFill))
|
||||
return hash((self.ids_of_folded_args, self.globalAddressArgIdx, tuple(self.globalDimsArgIdx),
|
||||
tuple(self.globalStridesArgIdx), self.tensorDataType, self.tensorRank, tuple(self.boxDims),
|
||||
tuple(self.elementStrides), self.interleave, self.swizzle, self.l2Promotion, self.oobFill))
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
return (self.ids_of_folded_args, self.globalAddressArgIdx, self.globalDimsArgIdx, self.globalStridesArgIdx, self.tensorDataType, self.tensorRank, self.boxDims, self.elementStrides, self.interleave, self.swizzle, self.l2Promotion, self.oobFill) == (
|
||||
other.ids_of_folded_args, other.globalAddressArgIdx, other.globalDimsArgIdx, other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims, other.elementStrides, other.interleave, other.swizzle, other.l2Promotion, other.oobFill)
|
||||
return (self.ids_of_folded_args, self.globalAddressArgIdx, self.globalDimsArgIdx, self.globalStridesArgIdx,
|
||||
self.tensorDataType, self.tensorRank, self.boxDims, self.elementStrides, self.interleave, self.swizzle,
|
||||
self.l2Promotion,
|
||||
self.oobFill) == (other.ids_of_folded_args, other.globalAddressArgIdx, other.globalDimsArgIdx,
|
||||
other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims,
|
||||
other.elementStrides, other.interleave, other.swizzle, other.l2Promotion,
|
||||
other.oobFill)
|
||||
|
||||
|
||||
class TensorMapManager:
|
||||
|
||||
def __init__(self):
|
||||
self.tensormaps_device = {}
|
||||
|
||||
@@ -286,8 +295,7 @@ class TensorMapManager:
|
||||
t_tensormap = e.tensormap(args)
|
||||
TENSORMAP_SIZE_IN_BYTES = 128
|
||||
t_tensormap_device = driver.utils.cuMemAlloc(TENSORMAP_SIZE_IN_BYTES)
|
||||
driver.utils.cuMemcpyHtoD(
|
||||
t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES)
|
||||
driver.utils.cuMemcpyHtoD(t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES)
|
||||
self.tensormaps_device[key] = t_tensormap_device
|
||||
return int(self.tensormaps_device[key])
|
||||
|
||||
|
||||
@@ -111,7 +111,6 @@ from .random import (
|
||||
uint32_to_uniform_float,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TRITON_MAX_TENSOR_NUMEL",
|
||||
"abs",
|
||||
|
||||
@@ -22,10 +22,8 @@ def builtin(fn: T) -> T:
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if "_builder" not in kwargs or kwargs["_builder"] is None:
|
||||
raise ValueError(
|
||||
"Did you forget to add @triton.jit ? "
|
||||
"(`_builder` argument must be provided outside of JIT functions.)"
|
||||
)
|
||||
raise ValueError("Did you forget to add @triton.jit ? "
|
||||
"(`_builder` argument must be provided outside of JIT functions.)")
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
setattr(wrapper, TRITON_BUILTIN, True)
|
||||
@@ -54,7 +52,7 @@ def _to_tensor(x, builder):
|
||||
else:
|
||||
raise RuntimeError(f'Nonrepresentable integer {x}.')
|
||||
elif isinstance(x, float):
|
||||
min_float32 = 2 ** -126
|
||||
min_float32 = 2**-126
|
||||
max_float32 = (2 - 2**-23) * 2**127
|
||||
abs_x = __builtins__['abs'](x)
|
||||
if abs_x == float("inf") or\
|
||||
@@ -243,7 +241,7 @@ class dtype:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.name,))
|
||||
return hash((self.name, ))
|
||||
|
||||
@property
|
||||
def scalar(self):
|
||||
@@ -297,6 +295,7 @@ class dtype:
|
||||
|
||||
|
||||
class pointer_type(dtype):
|
||||
|
||||
def __init__(self, element_ty: dtype, address_space: int = 1):
|
||||
if not isinstance(element_ty, dtype):
|
||||
raise TypeError('element_ty is a {type(element_ty).__name__}.')
|
||||
@@ -331,6 +330,7 @@ class pointer_type(dtype):
|
||||
|
||||
|
||||
class block_type(dtype):
|
||||
|
||||
def __init__(self, element_ty: dtype, shape: List):
|
||||
self.element_ty = element_ty
|
||||
|
||||
@@ -381,6 +381,7 @@ class block_type(dtype):
|
||||
|
||||
|
||||
class function_type(dtype):
|
||||
|
||||
def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None:
|
||||
self.ret_types = ret_types
|
||||
self.param_types = param_types
|
||||
@@ -531,7 +532,7 @@ class constexpr:
|
||||
return constexpr(~self.value)
|
||||
|
||||
def __pow__(self, other):
|
||||
return constexpr(self.value ** other.value)
|
||||
return constexpr(self.value**other.value)
|
||||
|
||||
def __rshift__(self, other):
|
||||
return constexpr(self.value >> other.value)
|
||||
@@ -547,6 +548,7 @@ class constexpr:
|
||||
|
||||
|
||||
class tensor:
|
||||
|
||||
def __init__(self, handle, type: dtype):
|
||||
# IR handle
|
||||
self.handle = handle
|
||||
@@ -740,11 +742,21 @@ class tensor:
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.equal(self, other, _builder)
|
||||
|
||||
@builtin
|
||||
def __req__(self, other, _builder=None):
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.equal(other, self, _builder)
|
||||
|
||||
@builtin
|
||||
def __ne__(self, other, _builder=None):
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.not_equal(self, other, _builder)
|
||||
|
||||
@builtin
|
||||
def __rne__(self, other, _builder=None):
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.not_equal(other, self, _builder)
|
||||
|
||||
@builtin
|
||||
def logical_and(self, other, _builder=None):
|
||||
other = _to_tensor(other, _builder)
|
||||
@@ -1023,6 +1035,7 @@ def expand_dims(input, axis, _builder=None):
|
||||
ret = semantic.expand_dims(ret, a, _builder)
|
||||
return ret
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Linear Algebra
|
||||
# -----------------------
|
||||
@@ -1171,6 +1184,7 @@ def advance(base: tensor, offsets, _builder=None):
|
||||
"""
|
||||
return semantic.advance(base, offsets, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Atomic Memory Operations
|
||||
# -----------------------
|
||||
@@ -1196,6 +1210,9 @@ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
|
||||
:param sem: Memory semantics to use ("ACQUIRE_RELEASE" (default),
|
||||
"ACQUIRE", "RELEASE", or "RELAXED")
|
||||
:type sem: str
|
||||
:param scope: Scope of threads that observe synchronizing effect of the
|
||||
atomic operation ("GPU" (default), "CTA", or "SYSTEM")
|
||||
:type scope: str
|
||||
"""
|
||||
func.__doc__ = docstr
|
||||
return func
|
||||
@@ -1205,73 +1222,82 @@ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("compare-and-swap", has_cmp=True)
|
||||
def atomic_cas(pointer, cmp, val, sem=None, _builder=None):
|
||||
def atomic_cas(pointer, cmp, val, sem=None, scope=None, _builder=None):
|
||||
cmp = _to_tensor(cmp, _builder)
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_cas(pointer, cmp, val, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_cas(pointer, cmp, val, sem, scope, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("exchange")
|
||||
def atomic_xchg(pointer, val, mask=None, sem=None, _builder=None):
|
||||
def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_xchg(pointer, val, mask, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_xchg(pointer, val, mask, sem, scope, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("add")
|
||||
def atomic_add(pointer, val, mask=None, sem=None, _builder=None):
|
||||
def atomic_add(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_add(pointer, val, mask, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_add(pointer, val, mask, sem, scope, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("max")
|
||||
def atomic_max(pointer, val, mask=None, sem=None, _builder=None):
|
||||
def atomic_max(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_max(pointer, val, mask, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_max(pointer, val, mask, sem, scope, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("min")
|
||||
def atomic_min(pointer, val, mask=None, sem=None, _builder=None):
|
||||
def atomic_min(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_min(pointer, val, mask, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_min(pointer, val, mask, sem, scope, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("logical and")
|
||||
def atomic_and(pointer, val, mask=None, sem=None, _builder=None):
|
||||
def atomic_and(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_and(pointer, val, mask, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_and(pointer, val, mask, sem, scope, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("logical or")
|
||||
def atomic_or(pointer, val, mask=None, sem=None, _builder=None):
|
||||
def atomic_or(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_or(pointer, val, mask, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_or(pointer, val, mask, sem, scope, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("logical xor")
|
||||
def atomic_xor(pointer, val, mask=None, sem=None, _builder=None):
|
||||
def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_xor(pointer, val, mask, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_xor(pointer, val, mask, sem, scope, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Conditioning
|
||||
# -----------------------
|
||||
|
||||
|
||||
@builtin
|
||||
def where(condition, x, y, _builder=None):
|
||||
"""
|
||||
@@ -1299,6 +1325,7 @@ def where(condition, x, y, _builder=None):
|
||||
# Math
|
||||
# -----------------------
|
||||
|
||||
|
||||
@builtin
|
||||
def umulhi(x, y, _builder=None):
|
||||
"""
|
||||
@@ -1392,6 +1419,7 @@ def abs(x, _builder=None):
|
||||
# Reductions
|
||||
# -----------------------
|
||||
|
||||
|
||||
def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]:
|
||||
|
||||
def _decorator(func: T) -> T:
|
||||
@@ -1430,8 +1458,7 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
|
||||
"""
|
||||
if isinstance(input, tensor):
|
||||
return reduce((input,), axis, combine_fn,
|
||||
_builder=_builder, _generator=_generator)[0]
|
||||
return reduce((input, ), axis, combine_fn, _builder=_builder, _generator=_generator)[0]
|
||||
|
||||
def make_combine_region(reduce_op):
|
||||
in_scalar_tys = [t.type.scalar for t in input]
|
||||
@@ -1441,14 +1468,14 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
with _insertion_guard(_builder):
|
||||
param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
|
||||
block = _builder.create_block_with_parent(region, param_types)
|
||||
args = [tensor(block.arg(i), ty)
|
||||
for i, ty in enumerate(prototype.param_types)]
|
||||
args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
|
||||
results = _generator.call_JitFunction(combine_fn, args, kwargs={})
|
||||
if isinstance(results, tensor):
|
||||
handles = [results.handle]
|
||||
else:
|
||||
handles = [r.handle for r in results]
|
||||
_builder.create_reduce_ret(*handles)
|
||||
|
||||
if axis is not None:
|
||||
axis = _constexpr_to_value(axis)
|
||||
return semantic.reduction(input, axis, make_combine_region, _builder)
|
||||
@@ -1483,8 +1510,7 @@ def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None
|
||||
index = expand_dims(index, axes_to_expand, _builder=_builder)
|
||||
index = broadcast_to(index, input.shape, _builder=_builder)
|
||||
|
||||
rvalue, rindices = reduce((input, index), axis, combine_fn,
|
||||
_builder=_builder, _generator=_generator)
|
||||
rvalue, rindices = reduce((input, index), axis, combine_fn, _builder=_builder, _generator=_generator)
|
||||
return rvalue, rindices
|
||||
|
||||
|
||||
@@ -1492,6 +1518,7 @@ def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None
|
||||
# Scans
|
||||
# -----------------------
|
||||
|
||||
|
||||
def _add_scan_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]:
|
||||
|
||||
def _decorator(func: T) -> T:
|
||||
@@ -1516,8 +1543,7 @@ def associative_scan(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
|
||||
"""
|
||||
if isinstance(input, tensor):
|
||||
return associative_scan((input,), axis, combine_fn,
|
||||
_builder=_builder, _generator=_generator)[0]
|
||||
return associative_scan((input, ), axis, combine_fn, _builder=_builder, _generator=_generator)[0]
|
||||
|
||||
def make_combine_region(scan_op):
|
||||
in_scalar_tys = [t.type.scalar for t in input]
|
||||
@@ -1527,17 +1553,18 @@ def associative_scan(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
with _insertion_guard(_builder):
|
||||
param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
|
||||
block = _builder.create_block_with_parent(region, param_types)
|
||||
args = [tensor(block.arg(i), ty)
|
||||
for i, ty in enumerate(prototype.param_types)]
|
||||
args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
|
||||
results = _generator.call_JitFunction(combine_fn, args, kwargs={})
|
||||
if isinstance(results, tensor):
|
||||
handles = [results.handle]
|
||||
else:
|
||||
handles = [r.handle for r in results]
|
||||
_builder.create_scan_ret(*handles)
|
||||
|
||||
axis = _constexpr_to_value(axis)
|
||||
return semantic.associative_scan(input, axis, make_combine_region, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Compiler Hint Ops
|
||||
# -----------------------
|
||||
@@ -1600,6 +1627,8 @@ def max_constancy(input, values, _builder=None):
|
||||
raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
|
||||
values = [x.value for x in values]
|
||||
return semantic.max_constancy(input, values)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Debugging functions
|
||||
# -----------------------
|
||||
@@ -1739,12 +1768,12 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur
|
||||
broadcast_arg = dispatch_args[0]
|
||||
# Get the broadcast shape over all the arguments
|
||||
for i, item in enumerate(dispatch_args):
|
||||
_, broadcast_arg = semantic.binary_op_type_checking_impl(
|
||||
item, broadcast_arg, _builder, arithmetic_check=False)
|
||||
_, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder,
|
||||
arithmetic_check=False)
|
||||
# Change the shape of each argument based on the broadcast shape
|
||||
for i in range(len(dispatch_args)):
|
||||
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
|
||||
dispatch_args[i], broadcast_arg, _builder, arithmetic_check=False)
|
||||
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder,
|
||||
arithmetic_check=False)
|
||||
ret_shape = broadcast_arg.shape
|
||||
res_ty = block_type(dtype, ret_shape)
|
||||
call = _builder.create_inline_asm(asm, constraints, [t.handle for t in args], res_ty.to_ir(_builder), is_pure, pack)
|
||||
@@ -1757,7 +1786,6 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur
|
||||
|
||||
|
||||
class static_range:
|
||||
|
||||
"""
|
||||
Iterator that counts upward forever.
|
||||
|
||||
@@ -1801,7 +1829,9 @@ class static_range:
|
||||
# Extern functions
|
||||
# -----------------------
|
||||
|
||||
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, is_pure: bool, _builder=None):
|
||||
|
||||
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple,
|
||||
is_pure: bool, _builder=None):
|
||||
'''
|
||||
Dispatch a function to a library
|
||||
:param func: the function to dispatch
|
||||
@@ -1843,7 +1873,8 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic
|
||||
return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type)
|
||||
|
||||
|
||||
def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, _builder=None):
|
||||
def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool,
|
||||
_builder=None):
|
||||
'''
|
||||
Dispatch an elementwise function to a library
|
||||
:param lib_name: the name of the library
|
||||
@@ -1872,12 +1903,12 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
|
||||
broadcast_arg = dispatch_args[0]
|
||||
# Get the broadcast shape over all the arguments
|
||||
for i, item in enumerate(dispatch_args):
|
||||
_, broadcast_arg = semantic.binary_op_type_checking_impl(
|
||||
item, broadcast_arg, _builder, arithmetic_check=arithmetic_check)
|
||||
_, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder,
|
||||
arithmetic_check=arithmetic_check)
|
||||
# Change the shape of each argument based on the broadcast shape
|
||||
for i in range(len(dispatch_args)):
|
||||
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
|
||||
dispatch_args[i], broadcast_arg, _builder, arithmetic_check=arithmetic_check)
|
||||
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder,
|
||||
arithmetic_check=arithmetic_check)
|
||||
if not all_scalar:
|
||||
ret_shape = broadcast_arg.shape
|
||||
func = getattr(_builder, "create_extern_elementwise")
|
||||
|
||||
@@ -3,16 +3,14 @@ from .. import core
|
||||
|
||||
@core.extern
|
||||
def globaltimer(_builder=None):
|
||||
return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [],
|
||||
dtype=core.int64, is_pure=False,
|
||||
pack=1, _builder=_builder)
|
||||
return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], dtype=core.int64, is_pure=False, pack=1,
|
||||
_builder=_builder)
|
||||
|
||||
|
||||
@core.extern
|
||||
def smid(_builder=None):
|
||||
return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [],
|
||||
dtype=core.int32, is_pure=True,
|
||||
pack=1, _builder=_builder)
|
||||
return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1,
|
||||
_builder=_builder)
|
||||
|
||||
|
||||
@core.builtin
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -91,6 +91,7 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
# two_to_the_minus_32: tl.constexpr = 2.328306e-10
|
||||
# return x * two_to_the_minus_32
|
||||
|
||||
|
||||
@jit
|
||||
def uint32_to_uniform_float(x):
|
||||
"""
|
||||
@@ -134,6 +135,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
u4 = uint32_to_uniform_float(i4)
|
||||
return u1, u2, u3, u4
|
||||
|
||||
|
||||
# -------------------
|
||||
# randn
|
||||
# -------------------
|
||||
|
||||
@@ -19,10 +19,12 @@ def _is_cuda(target):
|
||||
from ..compiler.compiler import CudaTargetDescriptor
|
||||
return isinstance(target, CudaTargetDescriptor)
|
||||
|
||||
|
||||
# Create custom exception that prints message "hello"
|
||||
|
||||
|
||||
class IncompatibleTypeErrorImpl(Exception):
|
||||
|
||||
def __init__(self, type_a, type_b):
|
||||
self.type_a = type_a
|
||||
self.type_b = type_b
|
||||
@@ -34,6 +36,7 @@ class IncompatibleTypeErrorImpl(Exception):
|
||||
# Programming Model
|
||||
# ===----------------------------------------------------------------------===##
|
||||
|
||||
|
||||
def program_id(axis: int, builder: ir.builder) -> tl.tensor:
|
||||
if axis not in (0, 1, 2):
|
||||
raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}")
|
||||
@@ -45,6 +48,7 @@ def num_programs(axis: int, builder: ir.builder) -> tl.tensor:
|
||||
raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}")
|
||||
return tl.tensor(builder.create_get_num_programs(axis), tl.int32)
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
# Implicit Casting Utilities
|
||||
# ===----------------------------------------------------------------------===//
|
||||
@@ -95,10 +99,12 @@ def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> t
|
||||
# 5 ) both operands are integer and undergo
|
||||
# integer promotion
|
||||
if div_or_mod and a_ty.int_signedness != b_ty.int_signedness:
|
||||
raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + " because they have different signedness;"
|
||||
raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() +
|
||||
" because they have different signedness;"
|
||||
"this is unlikely to result in a useful answer. Cast them to the same signedness.")
|
||||
return integer_promote_impl(a_ty, b_ty)
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
# Binary Operators
|
||||
# ===----------------------------------------------------------------------===//
|
||||
@@ -116,12 +122,9 @@ def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -
|
||||
raise IncompatibleTypeErrorImpl(type_a, type_b)
|
||||
|
||||
|
||||
def binary_op_type_checking_impl(lhs: tl.tensor,
|
||||
rhs: tl.tensor,
|
||||
builder: ir.builder,
|
||||
allow_lhs_ptr=False, allow_rhs_ptr=False,
|
||||
arithmetic_check=True, div_or_mod=False
|
||||
) -> Tuple[tl.tensor, tl.tensor]:
|
||||
def binary_op_type_checking_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder, allow_lhs_ptr=False,
|
||||
allow_rhs_ptr=False, arithmetic_check=True,
|
||||
div_or_mod=False) -> Tuple[tl.tensor, tl.tensor]:
|
||||
# implicit broadcasting
|
||||
lhs, rhs = broadcast_impl_value(lhs, rhs, builder)
|
||||
# implicit typecasting
|
||||
@@ -136,9 +139,7 @@ def binary_op_type_checking_impl(lhs: tl.tensor,
|
||||
return lhs, rhs
|
||||
|
||||
|
||||
def add(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def add(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder, True, True)
|
||||
input_scalar_ty = input.type.scalar
|
||||
other_scalar_ty = other.type.scalar
|
||||
@@ -162,15 +163,12 @@ def add(input: tl.tensor,
|
||||
assert False
|
||||
|
||||
|
||||
def sub(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def sub(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder, True, False)
|
||||
scalar_ty = input.type.scalar
|
||||
# ptr - offset
|
||||
if scalar_ty.is_ptr():
|
||||
return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle),
|
||||
input.type)
|
||||
return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle), input.type)
|
||||
# float - float
|
||||
if scalar_ty.is_floating():
|
||||
return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type)
|
||||
@@ -180,9 +178,7 @@ def sub(input: tl.tensor,
|
||||
assert False
|
||||
|
||||
|
||||
def mul(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def mul(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder)
|
||||
scalar_ty = input.type.scalar
|
||||
# float * float
|
||||
@@ -194,9 +190,7 @@ def mul(input: tl.tensor,
|
||||
assert False
|
||||
|
||||
|
||||
def truediv(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def truediv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
|
||||
input_scalar_ty = input.type.scalar
|
||||
other_scalar_ty = other.type.scalar
|
||||
@@ -222,9 +216,7 @@ def truediv(input: tl.tensor,
|
||||
return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type)
|
||||
|
||||
|
||||
def floordiv(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def floordiv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
|
||||
input_scalar_ty = input.type.scalar
|
||||
other_scalar_ty = other.type.scalar
|
||||
@@ -239,10 +231,7 @@ def floordiv(input: tl.tensor,
|
||||
assert False
|
||||
|
||||
|
||||
def fdiv(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
ieee_rounding: bool,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def fdiv(input: tl.tensor, other: tl.tensor, ieee_rounding: bool, builder: ir.builder) -> tl.tensor:
|
||||
input_scalar_ty = input.type.scalar
|
||||
other_scalar_ty = other.type.scalar
|
||||
if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating():
|
||||
@@ -252,18 +241,14 @@ def fdiv(input: tl.tensor,
|
||||
return tl.tensor(ret, input.type)
|
||||
|
||||
|
||||
def mod(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def mod(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
|
||||
scalar_ty = input.type.scalar
|
||||
other_scalar_ty = other.type.scalar
|
||||
# float % float
|
||||
if scalar_ty.is_floating():
|
||||
# input - input.div(other, rounding_mode="floor") * other
|
||||
ret = sub(input, mul(floor(fdiv(input, other, False, builder), builder),
|
||||
other, builder),
|
||||
builder)
|
||||
ret = sub(input, mul(floor(fdiv(input, other, False, builder), builder), other, builder), builder)
|
||||
return ret
|
||||
# % int
|
||||
elif scalar_ty.is_int():
|
||||
@@ -277,13 +262,13 @@ def mod(input: tl.tensor,
|
||||
return tl.tensor(builder.create_urem(input.handle, other.handle), input.type)
|
||||
assert False
|
||||
|
||||
|
||||
##############
|
||||
# bitwise ops
|
||||
##############
|
||||
|
||||
|
||||
def bitwise_op_type_checking_impl(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
def bitwise_op_type_checking_impl(input: tl.tensor, other: tl.tensor,
|
||||
builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder, False, False, False)
|
||||
input_sca_ty = input.type.scalar
|
||||
@@ -298,23 +283,17 @@ def bitwise_op_type_checking_impl(input: tl.tensor,
|
||||
return input, other
|
||||
|
||||
|
||||
def and_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Build the bitwise AND of ``input`` and ``other``.

    Both operands first pass through ``bitwise_op_type_checking_impl``; the
    result tensor carries the type of the checked left operand.
    """
    lhs, rhs = bitwise_op_type_checking_impl(input, other, builder)
    handle = builder.create_and(lhs.handle, rhs.handle)
    return tl.tensor(handle, lhs.type)
|
||||
|
||||
|
||||
def or_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Build the bitwise OR of ``input`` and ``other``.

    Both operands first pass through ``bitwise_op_type_checking_impl``; the
    result tensor carries the type of the checked left operand.
    """
    lhs, rhs = bitwise_op_type_checking_impl(input, other, builder)
    handle = builder.create_or(lhs.handle, rhs.handle)
    return tl.tensor(handle, lhs.type)
|
||||
|
||||
|
||||
def xor_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Build the bitwise XOR of ``input`` and ``other``.

    Both operands first pass through ``bitwise_op_type_checking_impl``; the
    result tensor carries the type of the checked left operand.
    """
    lhs, rhs = bitwise_op_type_checking_impl(input, other, builder)
    handle = builder.create_xor(lhs.handle, rhs.handle)
    return tl.tensor(handle, lhs.type)
|
||||
|
||||
@@ -341,26 +320,21 @@ def not_(input: tl.tensor, builder: ir.builder):
|
||||
return invert(input, builder)
|
||||
|
||||
|
||||
def lshr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Shift ``input`` right by ``other`` using the builder's ``lshr`` op.

    Operands are validated with ``bitwise_op_type_checking_impl`` and the
    result keeps the type of the checked left operand.
    """
    value, amount = bitwise_op_type_checking_impl(input, other, builder)
    handle = builder.create_lshr(value.handle, amount.handle)
    return tl.tensor(handle, value.type)
|
||||
|
||||
|
||||
def ashr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Shift ``input`` right by ``other`` using the builder's ``ashr`` op.

    Operands are validated with ``bitwise_op_type_checking_impl`` and the
    result keeps the type of the checked left operand.
    """
    value, amount = bitwise_op_type_checking_impl(input, other, builder)
    handle = builder.create_ashr(value.handle, amount.handle)
    return tl.tensor(handle, value.type)
|
||||
|
||||
|
||||
def shl(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Shift ``input`` left by ``other`` using the builder's ``shl`` op.

    Operands are validated with ``bitwise_op_type_checking_impl`` and the
    result keeps the type of the checked left operand.
    """
    value, amount = bitwise_op_type_checking_impl(input, other, builder)
    handle = builder.create_shl(value.handle, amount.handle)
    return tl.tensor(handle, value.type)
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
# Unary Operators
|
||||
# ===----------------------------------------------------------------------===//
|
||||
@@ -370,8 +344,7 @@ def plus(input: tl.tensor) -> tl.tensor:
|
||||
return input
|
||||
|
||||
|
||||
def minus(input: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input_sca_ty = input.type.scalar
|
||||
if input_sca_ty.is_ptr():
|
||||
raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")")
|
||||
@@ -379,8 +352,7 @@ def minus(input: tl.tensor,
|
||||
return sub(_0, input, builder)
|
||||
|
||||
|
||||
def invert(input: tl.tensor,
|
||||
builder: tl.tensor) -> tl.tensor:
|
||||
def invert(input: tl.tensor, builder: tl.tensor) -> tl.tensor:
|
||||
input_sca_ty = input.type.scalar
|
||||
if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
|
||||
raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")")
|
||||
@@ -398,9 +370,7 @@ def _bool_like(v: tl.tensor) -> tl.block_type:
|
||||
return tl.block_type(tl.int1, shape)
|
||||
|
||||
|
||||
def greater_than(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def greater_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder)
|
||||
scalar_ty = input.type.scalar
|
||||
# float > float
|
||||
@@ -415,9 +385,7 @@ def greater_than(input: tl.tensor,
|
||||
assert False
|
||||
|
||||
|
||||
def greater_equal(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def greater_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder)
|
||||
scalar_ty = input.type.scalar
|
||||
# float >= float
|
||||
@@ -432,9 +400,7 @@ def greater_equal(input: tl.tensor,
|
||||
assert False
|
||||
|
||||
|
||||
def less_than(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def less_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder)
|
||||
scalar_ty = input.type.scalar
|
||||
# float < float
|
||||
@@ -449,9 +415,7 @@ def less_than(input: tl.tensor,
|
||||
assert False
|
||||
|
||||
|
||||
def less_equal(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def less_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder)
|
||||
scalar_ty = input.type.scalar
|
||||
# float <= float
|
||||
@@ -466,9 +430,7 @@ def less_equal(input: tl.tensor,
|
||||
assert False
|
||||
|
||||
|
||||
def equal(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder)
|
||||
scalar_ty = input.type.scalar
|
||||
# float == float
|
||||
@@ -480,9 +442,7 @@ def equal(input: tl.tensor,
|
||||
assert False
|
||||
|
||||
|
||||
def not_equal(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def not_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder)
|
||||
scalar_ty = input.type.scalar
|
||||
# float != float
|
||||
@@ -493,6 +453,7 @@ def not_equal(input: tl.tensor,
|
||||
return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input))
|
||||
assert False
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
# Block Creation
|
||||
# ===----------------------------------------------------------------------===//
|
||||
@@ -541,6 +502,7 @@ def ones(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
|
||||
# Shape Manipulation
|
||||
# ===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor:
|
||||
assert not value.type.is_block(), "Cannot splat a block tensor"
|
||||
if len(shape) == 0:
|
||||
@@ -549,9 +511,7 @@ def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor:
|
||||
return tl.tensor(builder.create_splat(value.handle, shape), ret_ty)
|
||||
|
||||
|
||||
def view(input: tl.tensor,
|
||||
dst_shape: List[int],
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def view(input: tl.tensor, dst_shape: List[int], builder: ir.builder) -> tl.tensor:
|
||||
# TODO: disable when TritonToTritonGPU handles views properly
|
||||
|
||||
# assert len(input.shape) == len(dst_shape)
|
||||
@@ -564,9 +524,7 @@ def view(input: tl.tensor,
|
||||
return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty)
|
||||
|
||||
|
||||
def reshape(input: tl.tensor, dst_shape: List[int], builder: ir.builder) -> tl.tensor:
    """Unconditionally reject the call: ``reshape`` is not implemented.

    Raises:
        ValueError: always, with a message pointing the caller at ``view``.
    """
    unsupported_msg = ("`reshape` is not supported yet. Please use `view` instead if applicable. "
                       "Note that view may reorder elements in an implementation- and context- dependent way.")
    raise ValueError(unsupported_msg)
|
||||
|
||||
@@ -596,9 +554,7 @@ def trans(input: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
return tl.tensor(builder.create_trans(input.handle), ret_type)
|
||||
|
||||
|
||||
def broadcast_impl_shape(input: tl.tensor,
|
||||
shape: List[int],
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def broadcast_impl_shape(input: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor:
|
||||
if not input.type.is_block():
|
||||
ret_ty = tl.block_type(input.type, shape)
|
||||
return tl.tensor(builder.create_splat(input.handle, shape), ret_ty)
|
||||
@@ -616,9 +572,7 @@ def broadcast_impl_shape(input: tl.tensor,
|
||||
return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty)
|
||||
|
||||
|
||||
def broadcast_impl_value(lhs: tl.tensor,
|
||||
rhs: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
lhs_ty = lhs.type
|
||||
rhs_ty = rhs.type
|
||||
|
||||
@@ -638,13 +592,15 @@ def broadcast_impl_value(lhs: tl.tensor,
|
||||
if len(lhs_shape) < len(rhs_shape):
|
||||
# Add new axes to lhs
|
||||
for dim in range(len(lhs_shape), len(rhs_shape)):
|
||||
lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0), tl.block_type(lhs_ty.scalar, [1] + lhs_shape))
|
||||
lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0),
|
||||
tl.block_type(lhs_ty.scalar, [1] + lhs_shape))
|
||||
lhs_ty = lhs.type
|
||||
lhs_shape = lhs_ty.get_block_shapes()
|
||||
elif len(rhs_shape) < len(lhs_shape):
|
||||
# Add new axes to rhs
|
||||
for dim in range(len(rhs_shape), len(lhs_shape)):
|
||||
rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0), tl.block_type(rhs_ty.scalar, [1] + rhs_shape))
|
||||
rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0),
|
||||
tl.block_type(rhs_ty.scalar, [1] + rhs_shape))
|
||||
rhs_ty = rhs.type
|
||||
rhs_shape = rhs_ty.get_block_shapes()
|
||||
assert len(rhs_shape) == len(lhs_shape)
|
||||
@@ -670,14 +626,13 @@ def broadcast_impl_value(lhs: tl.tensor,
|
||||
# (scalar, scalar) => returns original blocks
|
||||
return lhs, rhs
|
||||
|
||||
|
||||
#######
|
||||
# cast
|
||||
#######
|
||||
|
||||
|
||||
def bitcast(input: tl.tensor,
|
||||
dst_ty: tl.dtype,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor:
|
||||
src_ty = input.type
|
||||
if src_ty.is_block():
|
||||
dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes())
|
||||
@@ -693,13 +648,10 @@ def bitcast(input: tl.tensor,
|
||||
if src_bits != dst_bits:
|
||||
raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to "
|
||||
"data-type of size " + str(dst_bits))
|
||||
return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
|
||||
|
||||
def cast(input: tl.tensor,
|
||||
dst_ty: tl.dtype,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor:
|
||||
src_ty = input.type
|
||||
if isinstance(dst_ty, tl.constexpr):
|
||||
dst_ty = dst_ty.value
|
||||
@@ -718,8 +670,7 @@ def cast(input: tl.tensor,
|
||||
# Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64
|
||||
if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \
|
||||
(src_sca_ty.is_floating() and dst_sca_ty.is_fp8()):
|
||||
return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
|
||||
# bf16 <=> (not fp32)
|
||||
if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \
|
||||
@@ -733,9 +684,7 @@ def cast(input: tl.tensor,
|
||||
dst_sca_ty.is_floating() and \
|
||||
src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth
|
||||
if truncate_fp:
|
||||
return tl.tensor(builder.create_fp_trunc(input.handle,
|
||||
dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
|
||||
# Standard floating types' casting: extension
|
||||
# fp32 => fp64
|
||||
@@ -745,9 +694,7 @@ def cast(input: tl.tensor,
|
||||
dst_sca_ty.is_floating() and \
|
||||
src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth
|
||||
if ext_fp:
|
||||
return tl.tensor(builder.create_fp_ext(input.handle,
|
||||
dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
|
||||
# Casting between integer types
|
||||
if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
|
||||
@@ -758,9 +705,7 @@ def cast(input: tl.tensor,
|
||||
_0 = tl.tensor(builder.get_null_value(ty), input.dtype)
|
||||
return not_equal(input, _0, builder)
|
||||
else:
|
||||
return tl.tensor(builder.create_int_cast(input.handle,
|
||||
dst_ty.to_ir(builder), sign_extend),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty)
|
||||
|
||||
# Casting standard floating types to integer types
|
||||
if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int():
|
||||
@@ -769,35 +714,24 @@ def cast(input: tl.tensor,
|
||||
_0 = tl.tensor(builder.get_null_value(ty), input.dtype)
|
||||
return not_equal(input, _0, builder)
|
||||
elif dst_sca_ty.is_int_signed():
|
||||
return tl.tensor(builder.create_fp_to_si(input.handle,
|
||||
dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
else:
|
||||
return tl.tensor(builder.create_fp_to_ui(input.handle,
|
||||
dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
|
||||
# Casting integer types to standard floating types
|
||||
if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating():
|
||||
if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed():
|
||||
return tl.tensor(builder.create_ui_to_fp(input.handle,
|
||||
dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
else:
|
||||
return tl.tensor(builder.create_si_to_fp(input.handle,
|
||||
dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
|
||||
# Casting pointer types to integer types
|
||||
if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
|
||||
bitwidth = dst_sca_ty.int_bitwidth
|
||||
if bitwidth == 64:
|
||||
return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
if bitwidth == 1:
|
||||
return not_equal(cast(input, tl.int64, builder),
|
||||
tl.tensor(builder.get_int64(0), tl.int64),
|
||||
builder)
|
||||
return not_equal(cast(input, tl.int64, builder), tl.tensor(builder.get_int64(0), tl.int64), builder)
|
||||
|
||||
# Casting integer types to pointer types
|
||||
if src_sca_ty.is_int() and dst_sca_ty.is_ptr():
|
||||
@@ -809,6 +743,7 @@ def cast(input: tl.tensor,
|
||||
|
||||
assert False, f'cannot cast {input} to {dst_ty}'
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
# Memory Operators
|
||||
# ===----------------------------------------------------------------------===//
|
||||
@@ -882,6 +817,20 @@ def _str_to_sem(sem_option):
|
||||
return sem
|
||||
|
||||
|
||||
def _str_to_scope(scope_option):
    """Translate a scope string into the matching ``ir.MEM_SYNC_SCOPE`` member.

    Recognised names are "gpu", "cta" and "sys". A falsy ``scope_option``
    (for example ``None`` or an empty string) selects ``ir.MEM_SYNC_SCOPE.GPU``.

    Raises:
        ValueError: if ``scope_option`` is truthy but not one of the names above.
    """
    if not scope_option:
        return ir.MEM_SYNC_SCOPE.GPU
    if scope_option == "gpu":
        return ir.MEM_SYNC_SCOPE.GPU
    if scope_option == "cta":
        return ir.MEM_SYNC_SCOPE.CTA
    if scope_option == "sys":
        return ir.MEM_SYNC_SCOPE.SYSTEM
    # The rejected value is a synchronization scope, so the message names it as one
    # (memory *semantics* are parsed separately by _str_to_sem).
    raise ValueError(f"Memory scope {scope_option} not supported")
|
||||
|
||||
|
||||
def _canonicalize_boundary_check(boundary_check, block_shape):
|
||||
if boundary_check:
|
||||
if not hasattr(boundary_check, "__iter__"):
|
||||
@@ -913,8 +862,8 @@ def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, evicti
|
||||
boundary_check = _canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes())
|
||||
|
||||
# Build IR
|
||||
return tl.tensor(builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction,
|
||||
is_volatile), dst_ty)
|
||||
return tl.tensor(
|
||||
builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty)
|
||||
|
||||
|
||||
def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder):
|
||||
@@ -970,19 +919,13 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_
|
||||
if not mask:
|
||||
return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty)
|
||||
else:
|
||||
return tl.tensor(builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache,
|
||||
eviction, is_volatile), dst_ty)
|
||||
return tl.tensor(
|
||||
builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction,
|
||||
is_volatile), dst_ty)
|
||||
|
||||
|
||||
def load(ptr: tl.tensor,
|
||||
mask: Optional[tl.tensor],
|
||||
other: Optional[tl.tensor],
|
||||
boundary_check,
|
||||
padding_option: str,
|
||||
cache_modifier: str,
|
||||
eviction_policy: str,
|
||||
is_volatile: bool,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check, padding_option: str,
|
||||
cache_modifier: str, eviction_policy: str, is_volatile: bool, builder: ir.builder) -> tl.tensor:
|
||||
# Cache, eviction and padding options
|
||||
cache = _str_to_load_cache_modifier(cache_modifier)
|
||||
eviction = _str_to_eviction_policy(eviction_policy)
|
||||
@@ -1007,7 +950,8 @@ def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builde
|
||||
if not val.type.is_block():
|
||||
val = broadcast_impl_shape(val, block_shape, builder)
|
||||
assert val.type.is_block(), "Value argument must be block type or a scalar"
|
||||
assert block_shape == val.type.get_block_shapes(), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch"
|
||||
assert block_shape == val.type.get_block_shapes(
|
||||
), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch"
|
||||
assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch"
|
||||
|
||||
elt_ty = ptr.type.element_ty.element_ty
|
||||
@@ -1065,13 +1009,8 @@ def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder):
|
||||
return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void)
|
||||
|
||||
|
||||
def store(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: Optional[tl.tensor],
|
||||
boundary_check,
|
||||
cache_modifier: str,
|
||||
eviction_policy: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], boundary_check, cache_modifier: str,
|
||||
eviction_policy: str, builder: ir.builder) -> tl.tensor:
|
||||
# Cache and eviction options
|
||||
cache = _str_to_store_cache_modifier(cache_modifier)
|
||||
eviction = _str_to_eviction_policy(eviction_policy)
|
||||
@@ -1089,22 +1028,16 @@ def store(ptr: tl.tensor,
|
||||
#########
|
||||
|
||||
|
||||
def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
    """Build an atomic compare-and-swap on ``ptr`` and wrap it as a tensor of ``val.type``.

    ``sem`` and ``scope`` are strings converted by ``_str_to_sem`` and
    ``_str_to_scope`` before being handed to the builder.

    Raises:
        ValueError: if the pointee's ``primitive_bitwidth`` is not 16, 32 or 64.
    """
    # Parse the string options first so an invalid option is reported before the width check.
    sem_kind = _str_to_sem(sem)
    sync_scope = _str_to_scope(scope)
    supported_widths = (16, 32, 64)
    pointee_ty = ptr.type.scalar.element_ty
    if pointee_ty.primitive_bitwidth not in supported_widths:
        raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
    handle = builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem_kind, sync_scope)
    return tl.tensor(handle, val.type)
|
||||
|
||||
|
||||
def atom_red_typechecking_impl(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
op: str,
|
||||
def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, op: str,
|
||||
builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]:
|
||||
if not ptr.type.scalar.is_ptr():
|
||||
raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
|
||||
@@ -1129,30 +1062,19 @@ def atom_red_typechecking_impl(ptr: tl.tensor,
|
||||
return ptr, val, mask
|
||||
|
||||
|
||||
def atomic_max(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
sem: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder)
|
||||
sem = _str_to_sem(sem)
|
||||
scope = _str_to_scope(scope)
|
||||
sca_ty = val.type.scalar
|
||||
# direct call to atomic_max for integers
|
||||
if sca_ty.is_int():
|
||||
if sca_ty.is_int_signed():
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX,
|
||||
ptr.handle,
|
||||
val.handle,
|
||||
mask.handle,
|
||||
sem),
|
||||
val.type)
|
||||
return tl.tensor(
|
||||
builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
|
||||
else:
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX,
|
||||
ptr.handle,
|
||||
val.handle,
|
||||
mask.handle,
|
||||
sem),
|
||||
val.type)
|
||||
return tl.tensor(
|
||||
builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
|
||||
# ROCM TODO: implement atomic_max/min for f32 as they are supported by MI cards.
|
||||
# for float
|
||||
# return atomic_smax(i_ptr, i_val) if val >= 0
|
||||
@@ -1167,36 +1089,29 @@ def atomic_max(ptr: tl.tensor,
|
||||
i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder)
|
||||
pos = greater_equal(val, zero, builder)
|
||||
neg = less_than(val, zero, builder)
|
||||
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle, sem), i_val.type)
|
||||
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle, sem), i_val.type)
|
||||
pos_ret = tl.tensor(
|
||||
builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle,
|
||||
and_(mask, pos, builder).handle, sem, scope), i_val.type)
|
||||
neg_ret = tl.tensor(
|
||||
builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle,
|
||||
and_(mask, neg, builder).handle, sem, scope), i_val.type)
|
||||
ret = where(pos, pos_ret, neg_ret, builder)
|
||||
return bitcast(ret, sca_ty, builder)
|
||||
|
||||
|
||||
def atomic_min(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
sem: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder)
|
||||
sem = _str_to_sem(sem)
|
||||
scope = _str_to_scope(scope)
|
||||
sca_ty = val.type.scalar
|
||||
# direct call to atomic_min for integers
|
||||
if sca_ty.is_int():
|
||||
if sca_ty.is_int_signed():
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
|
||||
ptr.handle,
|
||||
val.handle,
|
||||
mask.handle,
|
||||
sem),
|
||||
val.type)
|
||||
return tl.tensor(
|
||||
builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
|
||||
else:
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN,
|
||||
ptr.handle,
|
||||
val.handle,
|
||||
mask.handle,
|
||||
sem),
|
||||
val.type)
|
||||
return tl.tensor(
|
||||
builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
|
||||
# for float
|
||||
# return atomic_smin(i_ptr, i_val) if val >= 0
|
||||
# return atomic_umax(i_ptr, i_val) if val < 0
|
||||
@@ -1210,72 +1125,57 @@ def atomic_min(ptr: tl.tensor,
|
||||
i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder)
|
||||
pos = greater_equal(val, zero, builder)
|
||||
neg = less_than(val, zero, builder)
|
||||
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
|
||||
i_ptr.handle,
|
||||
i_val.handle,
|
||||
and_(mask, pos, builder).handle,
|
||||
sem),
|
||||
i_val.type)
|
||||
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX,
|
||||
i_ptr.handle,
|
||||
i_val.handle,
|
||||
and_(mask, neg, builder).handle,
|
||||
sem),
|
||||
i_val.type)
|
||||
pos_ret = tl.tensor(
|
||||
builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle,
|
||||
and_(mask, pos, builder).handle, sem, scope), i_val.type)
|
||||
neg_ret = tl.tensor(
|
||||
builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, i_ptr.handle, i_val.handle,
|
||||
and_(mask, neg, builder).handle, sem, scope), i_val.type)
|
||||
ret = where(pos, pos_ret, neg_ret, builder)
|
||||
return bitcast(ret, sca_ty, builder)
|
||||
|
||||
|
||||
def atomic_add(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
sem: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def atomic_add(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder)
|
||||
sem = _str_to_sem(sem)
|
||||
scope = _str_to_scope(scope)
|
||||
sca_ty = val.type.scalar
|
||||
op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD
|
||||
return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem), val.type)
|
||||
return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
|
||||
|
||||
|
||||
def atomic_and(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
sem: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def atomic_and(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder)
|
||||
sem = _str_to_sem(sem)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem), val.type)
|
||||
scope = _str_to_scope(scope)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope),
|
||||
val.type)
|
||||
|
||||
|
||||
def atomic_or(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
sem: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def atomic_or(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder)
|
||||
sem = _str_to_sem(sem)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem), val.type)
|
||||
scope = _str_to_scope(scope)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope),
|
||||
val.type)
|
||||
|
||||
|
||||
def atomic_xor(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
sem: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def atomic_xor(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder)
|
||||
sem = _str_to_sem(sem)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem), val.type)
|
||||
scope = _str_to_scope(scope)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope),
|
||||
val.type)
|
||||
|
||||
|
||||
def atomic_xchg(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
sem: str,
|
||||
def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder)
|
||||
sem = _str_to_sem(sem)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem), val.type)
|
||||
scope = _str_to_scope(scope)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope),
|
||||
val.type)
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
# Linear Algebra
|
||||
@@ -1308,13 +1208,10 @@ def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty, target) -> bool:
|
||||
return False
|
||||
return True
|
||||
|
||||
def dot(lhs: tl.tensor,
|
||||
rhs: tl.tensor,
|
||||
acc: tl.tensor,
|
||||
allow_tf32: bool,
|
||||
max_num_imprecise_acc: int,
|
||||
out_dtype: tl.dtype,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
|
||||
def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_num_imprecise_acc: int,
|
||||
out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
|
||||
|
||||
def assert_dtypes_valid(lhs_dtype, rhs_dtype, target):
|
||||
# Checks for non-cuda archs
|
||||
if is_hip():
|
||||
@@ -1333,22 +1230,30 @@ def dot(lhs: tl.tensor,
|
||||
|
||||
# Checks for cuda archs
|
||||
if target.capability < 90:
|
||||
assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(), "Dot op does not support fp8e4nv on CUDA arch < 90"
|
||||
assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(
|
||||
), "Dot op does not support fp8e4nv on CUDA arch < 90"
|
||||
if lhs_dtype.is_fp8() and rhs_dtype.is_fp8():
|
||||
return
|
||||
assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
|
||||
else:
|
||||
assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(), "Dot op does not support fp8e4b15 on CUDA arch >= 90"
|
||||
assert not lhs_dtype.is_fp8e4b15x4() and not rhs_dtype.is_fp8e4b15x4(), "Dot op does not support fp8e4b15x4 on CUDA arch >= 90"
|
||||
assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(
|
||||
), "Dot op does not support fp8e4b15 on CUDA arch >= 90"
|
||||
assert not lhs_dtype.is_fp8e4b15x4() and not rhs_dtype.is_fp8e4b15x4(
|
||||
), "Dot op does not support fp8e4b15x4 on CUDA arch >= 90"
|
||||
if lhs_dtype.is_int() or rhs_dtype.is_int():
|
||||
assert lhs_dtype == rhs_dtype, f"Both operands must be same type. First operand ({lhs_dtype}) and second operand ({rhs_dtype})"
|
||||
assert lhs_dtype.is_int8() or lhs_dtype.is_uint8(), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})"
|
||||
assert lhs_dtype.is_int8() or lhs_dtype.is_uint8(
|
||||
), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})"
|
||||
elif lhs_dtype.is_fp8() or rhs_dtype.is_fp8():
|
||||
assert lhs_dtype.is_fp8e4nv() or lhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. First operand ({lhs_dtype})"
|
||||
assert rhs_dtype.is_fp8e4nv() or rhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. Second operand ({rhs_dtype})"
|
||||
assert lhs_dtype.is_fp8e4nv() or lhs_dtype.is_fp8e5(
|
||||
), f"Only supports fp8e4nv or fp8e5. First operand ({lhs_dtype})"
|
||||
assert rhs_dtype.is_fp8e4nv() or rhs_dtype.is_fp8e5(
|
||||
), f"Only supports fp8e4nv or fp8e5. Second operand ({rhs_dtype})"
|
||||
else:
|
||||
assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1(), f"Unsupported dtype {lhs_dtype}"
|
||||
assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1(), f"Unsupported dtype {rhs_dtype}"
|
||||
assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1(
|
||||
), f"Unsupported dtype {lhs_dtype}"
|
||||
assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1(
|
||||
), f"Unsupported dtype {rhs_dtype}"
|
||||
assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
|
||||
|
||||
assert lhs.type.is_block() and rhs.type.is_block()
|
||||
@@ -1389,7 +1294,8 @@ def dot(lhs: tl.tensor,
|
||||
_0 = builder.get_int32(0)
|
||||
ret_scalar_ty = tl.int32
|
||||
elif out_dtype.is_bf16():
|
||||
raise ValueError("out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`")
|
||||
raise ValueError(
|
||||
"out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`")
|
||||
elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16():
|
||||
_0 = builder.get_fp32(0)
|
||||
ret_scalar_ty = tl.float32
|
||||
@@ -1418,7 +1324,8 @@ def dot(lhs: tl.tensor,
|
||||
ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, max_num_imprecise_acc),
|
||||
ret_ty)
|
||||
return cast(ret, ret_scalar_ty, builder)
|
||||
if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty, builder.target) and ret_scalar_ty.primitive_bitwidth <= 32:
|
||||
if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32,
|
||||
ret_scalar_ty, builder.target) and ret_scalar_ty.primitive_bitwidth <= 32:
|
||||
# max_num_imprecise_acc does not yet apply to hip
|
||||
if is_hip():
|
||||
max_num_imprecise_acc = 0
|
||||
@@ -1445,23 +1352,21 @@ def dot(lhs: tl.tensor,
|
||||
assert acc.type == ret_ty
|
||||
|
||||
# max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
|
||||
if not (_is_cuda(builder.target) and builder.target.capability == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and ret_scalar_ty.is_fp32()):
|
||||
if not (_is_cuda(builder.target) and builder.target.capability == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8()
|
||||
and ret_scalar_ty.is_fp32()):
|
||||
max_num_imprecise_acc = 0
|
||||
if max_num_imprecise_acc is None:
|
||||
max_num_imprecise_acc = 2**30
|
||||
|
||||
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc),
|
||||
ret_ty)
|
||||
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc), ret_ty)
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
# Indexing
|
||||
# ===----------------------------------------------------------------------===//
|
||||
|
||||
def where(condition: tl.tensor,
|
||||
x: tl.tensor,
|
||||
y: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
|
||||
def where(condition: tl.tensor, x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
condition = cast(condition, tl.int1, builder)
|
||||
if condition.type.is_block():
|
||||
condition, x = broadcast_impl_value(condition, x, builder)
|
||||
@@ -1474,14 +1379,13 @@ def where(condition: tl.tensor,
|
||||
ret_ty = x.type
|
||||
return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
# Reduction
|
||||
# ===----------------------------------------------------------------------===
|
||||
|
||||
|
||||
def reduction(
|
||||
inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder
|
||||
) -> Tuple[tl.tensor, ...]:
|
||||
def reduction(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder) -> Tuple[tl.tensor, ...]:
|
||||
if axis is None:
|
||||
new_inputs = []
|
||||
for i in range(len(inputs)):
|
||||
@@ -1507,10 +1411,7 @@ def reduction(
|
||||
region_builder_fn(reduce_op)
|
||||
reduce_op.verify()
|
||||
|
||||
return tuple(
|
||||
wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar)
|
||||
for i in range(len(inputs))
|
||||
)
|
||||
return tuple(wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar) for i in range(len(inputs)))
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===
|
||||
@@ -1518,9 +1419,8 @@ def reduction(
|
||||
# ===----------------------------------------------------------------------===
|
||||
|
||||
|
||||
def associative_scan(
|
||||
inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder
|
||||
) -> Tuple[tl.tensor, ...]:
|
||||
def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn,
|
||||
builder: ir.builder) -> Tuple[tl.tensor, ...]:
|
||||
if len(inputs) != 1:
|
||||
raise ValueError("Current implementation only support single tensor input")
|
||||
shape = inputs[0].type.shape
|
||||
@@ -1533,16 +1433,14 @@ def associative_scan(
|
||||
region_builder_fn(scan_op)
|
||||
scan_op.verify()
|
||||
|
||||
return tuple(
|
||||
wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar)
|
||||
for i in range(len(inputs))
|
||||
)
|
||||
return tuple(wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar) for i in range(len(inputs)))
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===
|
||||
# Math
|
||||
# ===----------------------------------------------------------------------===
|
||||
|
||||
|
||||
def _check_dtype(dtypes: List[str]) -> T:
|
||||
"""
|
||||
We're following libdevice's convention to check accepted data types for math functions.
|
||||
@@ -1551,7 +1449,9 @@ def _check_dtype(dtypes: List[str]) -> T:
|
||||
We should let the users know that they are using and invoke explicit cast to convert
|
||||
the data type to the supported one.
|
||||
"""
|
||||
|
||||
def wrapper(fn):
|
||||
|
||||
@wraps(fn)
|
||||
def check(*args, **kwargs):
|
||||
# concatenate args and kwargs
|
||||
@@ -1560,6 +1460,7 @@ def _check_dtype(dtypes: List[str]) -> T:
|
||||
if arg.type.scalar.name not in dtypes:
|
||||
raise ValueError(f"Expected dtype {dtypes} but got {arg.type.scalar.name}")
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return check
|
||||
|
||||
return wrapper
|
||||
@@ -1645,6 +1546,15 @@ def debug_barrier(builder: ir.builder) -> tl.tensor:
|
||||
|
||||
|
||||
def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:
|
||||
# It makes sense visually for prefix to end in ": "; make it so. Also,
|
||||
# non-empty prefixes should start with " ".
|
||||
if not prefix.endswith(" ") and args:
|
||||
prefix += " "
|
||||
if not prefix.endswith(": ") and args:
|
||||
prefix = prefix[:-1] + ": "
|
||||
if len(prefix) > 2 and not prefix.startswith(" "):
|
||||
prefix = " " + prefix
|
||||
|
||||
new_args = []
|
||||
for arg in args:
|
||||
new_args.append(arg.handle)
|
||||
@@ -1654,8 +1564,8 @@ def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.
|
||||
def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor:
|
||||
cond_ty = cond.type
|
||||
if not cond_ty.is_block():
|
||||
cond_ty = tl.block_type(cond_ty.scalar, (1,))
|
||||
cond = tl.tensor(builder.create_splat(cond.handle, (1,)), cond_ty)
|
||||
cond_ty = tl.block_type(cond_ty.scalar, (1, ))
|
||||
cond = tl.tensor(builder.create_splat(cond.handle, (1, )), cond_ty)
|
||||
return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void)
|
||||
|
||||
|
||||
|
||||
@@ -123,6 +123,7 @@ def maximum(x, y):
|
||||
"""
|
||||
return math.max(x, y)
|
||||
|
||||
|
||||
# max and argmax
|
||||
|
||||
|
||||
@@ -149,8 +150,7 @@ def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_reduction_docstr("maximum",
|
||||
return_indices_arg="return_indices",
|
||||
@core._add_reduction_docstr("maximum", return_indices_arg="return_indices",
|
||||
tie_break_arg="return_indices_tie_break_left")
|
||||
def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
|
||||
input = core._promote_reduction_input(input)
|
||||
@@ -175,6 +175,7 @@ def argmax(input, axis, tie_break_left=True):
|
||||
(_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
|
||||
return ret
|
||||
|
||||
|
||||
# min and argmin
|
||||
|
||||
|
||||
@@ -201,8 +202,7 @@ def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_reduction_docstr("minimum",
|
||||
return_indices_arg="return_indices",
|
||||
@core._add_reduction_docstr("minimum", return_indices_arg="return_indices",
|
||||
tie_break_arg="return_indices_tie_break_left")
|
||||
def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
|
||||
input = core._promote_reduction_input(input)
|
||||
@@ -222,8 +222,7 @@ def min(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_reduction_docstr("minimum index",
|
||||
tie_break_arg="tie_break_left")
|
||||
@core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left")
|
||||
def argmin(input, axis, tie_break_left=True):
|
||||
_, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
|
||||
return ret
|
||||
@@ -233,6 +232,7 @@ def argmin(input, axis, tie_break_left=True):
|
||||
def _sum_combine(a, b):
|
||||
return a + b
|
||||
|
||||
|
||||
# sum
|
||||
|
||||
|
||||
@@ -247,6 +247,7 @@ def sum(input, axis=None):
|
||||
def _xor_combine(a, b):
|
||||
return a ^ b
|
||||
|
||||
|
||||
# xor sum
|
||||
|
||||
|
||||
@@ -258,8 +259,8 @@ def xor_sum(input, axis=None, _builder=None, _generator=None):
|
||||
raise ValueError("xor_sum only supported for integers")
|
||||
|
||||
input = core._promote_reduction_input(input, _builder=_builder)
|
||||
return core.reduce(input, axis, _xor_combine,
|
||||
_builder=_builder, _generator=_generator)
|
||||
return core.reduce(input, axis, _xor_combine, _builder=_builder, _generator=_generator)
|
||||
|
||||
|
||||
# cumsum
|
||||
|
||||
@@ -271,6 +272,7 @@ def cumsum(input, axis=0):
|
||||
input = core._promote_reduction_input(input)
|
||||
return core.associative_scan(input, axis, _sum_combine)
|
||||
|
||||
|
||||
# cumprod
|
||||
|
||||
|
||||
|
||||
@@ -17,15 +17,14 @@ from ... import language as tl
|
||||
'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
|
||||
})
|
||||
@jit
|
||||
def _sdd_kernel(
|
||||
A, B, C,
|
||||
stride_za, stride_ha, stride_ma, stride_ak,
|
||||
stride_zb, stride_hb, stride_bk, stride_nb,
|
||||
stride_zc, stride_hc, stride_mc, stride_nc,
|
||||
K, grid_offset, lut,
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
|
||||
BLOCK: tl.constexpr, EVEN_K: tl.constexpr
|
||||
):
|
||||
def _sdd_kernel(A, B, C, #
|
||||
stride_za, stride_ha, stride_ma, stride_ak, #
|
||||
stride_zb, stride_hb, stride_bk, stride_nb, #
|
||||
stride_zc, stride_hc, stride_mc, stride_nc, #
|
||||
K, grid_offset, lut, #
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, #
|
||||
BLOCK: tl.constexpr, EVEN_K: tl.constexpr #
|
||||
):
|
||||
# ------------ #
|
||||
# - Prologue - #
|
||||
# ------------ #
|
||||
@@ -104,13 +103,13 @@ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=
|
||||
c = out
|
||||
grid = [c.shape[1], 1, c.shape[0]]
|
||||
_sdd_kernel[grid](
|
||||
a, b, c,
|
||||
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
|
||||
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
|
||||
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
|
||||
Ka, 0, lut,
|
||||
TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4,
|
||||
num_warps=4,
|
||||
a, b, c, #
|
||||
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), #
|
||||
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), #
|
||||
c.stride(0), c.stride(1), c.stride(2), c.stride(3), #
|
||||
Ka, 0, lut, #
|
||||
TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4, #
|
||||
num_warps=4 #
|
||||
)
|
||||
return c
|
||||
|
||||
@@ -120,6 +119,7 @@ def sdd_lut(layout, block, device):
|
||||
lut = lut.contiguous()
|
||||
return lut, None
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Dense = Sparse x Dense (DSD)
|
||||
# This operation uses a look-up table that contains pre-computed pointer increments
|
||||
@@ -128,15 +128,14 @@ def sdd_lut(layout, block, device):
|
||||
|
||||
|
||||
@jit
|
||||
def _dsd_kernel(
|
||||
A, B, C,
|
||||
stride_az, stride_ha, stride_am, stride_ak,
|
||||
stride_zb, stride_hb, stride_bk, stride_bn,
|
||||
stride_zc, stride_hc, stride_cm, stride_cn,
|
||||
DS0, DS1, lut,
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
|
||||
):
|
||||
def _dsd_kernel(A, B, C, #
|
||||
stride_az, stride_ha, stride_am, stride_ak, #
|
||||
stride_zb, stride_hb, stride_bk, stride_bn, #
|
||||
stride_zc, stride_hc, stride_cm, stride_cn, #
|
||||
DS0, DS1, lut, #
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, #
|
||||
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr #
|
||||
):
|
||||
# ------------ #
|
||||
# - Prologue - #
|
||||
# ------------ #
|
||||
@@ -229,13 +228,13 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N
|
||||
# compute output
|
||||
grid = lambda meta: [cdiv(BS3, meta['TILE_N']), width, BS0]
|
||||
_dsd_kernel[grid](
|
||||
a, b, c,
|
||||
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
|
||||
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
|
||||
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
|
||||
BS3, AS1, lut,
|
||||
TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
|
||||
num_warps=4, GROUP_SIZE_M=4,
|
||||
a, b, c, #
|
||||
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), #
|
||||
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), #
|
||||
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), #
|
||||
BS3, AS1, lut, #
|
||||
TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4, #
|
||||
num_warps=4, GROUP_SIZE_M=4 #
|
||||
)
|
||||
# exit()
|
||||
return c
|
||||
@@ -337,6 +336,7 @@ def dsd_lut(layout, block, step, trans, device):
|
||||
# create locks
|
||||
return lut, width
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Dense = Dense x Sparse (DDS)
|
||||
# -----------------------------
|
||||
@@ -346,6 +346,7 @@ def dsd_lut(layout, block, step, trans, device):
|
||||
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
|
||||
return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out)
|
||||
|
||||
|
||||
##############
|
||||
# MAIN API #
|
||||
##############
|
||||
@@ -356,10 +357,8 @@ class _matmul(torch.autograd.Function):
|
||||
fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul}
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
|
||||
c_lut, c_width, da_lut, da_width, db_lut, db_width, out
|
||||
):
|
||||
def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_width, da_lut, da_width, db_lut,
|
||||
db_width, out):
|
||||
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out)
|
||||
# save for backward
|
||||
ctx.save_for_backward(a, b)
|
||||
@@ -385,15 +384,13 @@ class _matmul(torch.autograd.Function):
|
||||
# gradients w.r.t. a
|
||||
if ctx.needs_input_grad[0]:
|
||||
mode_da = mode[1] + mode[0] + mode[2]
|
||||
da = _matmul.fn[mode_da](
|
||||
dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width,
|
||||
)
|
||||
da = _matmul.fn[mode_da](dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block,
|
||||
ctx.da_lut, ctx.da_width)
|
||||
# gradients w.r.t. b
|
||||
if ctx.needs_input_grad[1]:
|
||||
mode_db = mode[2] + mode[1] + mode[0]
|
||||
db = _matmul.fn[mode_db](
|
||||
a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width,
|
||||
)
|
||||
db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block,
|
||||
ctx.db_lut, ctx.db_width)
|
||||
dout = dc if ctx.has_out else None
|
||||
return da, db, None, None, None, \
|
||||
None, None, None, None, \
|
||||
@@ -427,11 +424,9 @@ class matmul:
|
||||
self.db_lut, self.db_width = sdd_lut(layout, block, device)
|
||||
|
||||
def __call__(self, a, b, out=None):
|
||||
c = _matmul.apply(
|
||||
a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
|
||||
self.c_lut, self.c_width,
|
||||
self.da_lut, self.da_width,
|
||||
self.db_lut, self.db_width,
|
||||
out
|
||||
)
|
||||
c = _matmul.apply(a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, #
|
||||
self.c_lut, self.c_width, #
|
||||
self.da_lut, self.da_width, #
|
||||
self.db_lut, self.db_width, #
|
||||
out)
|
||||
return c
|
||||
|
||||
@@ -18,14 +18,13 @@ def num_warps(n):
|
||||
|
||||
|
||||
@jit
|
||||
def _blocksparse_softmax_fwd(
|
||||
Out, A, stride_xz, LUT,
|
||||
R, extent, stride_zr, stride_hr, # relative attention
|
||||
scale, is_causal,
|
||||
ROW_SIZE: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
IS_DENSE: tl.constexpr,
|
||||
):
|
||||
def _blocksparse_softmax_fwd(Out, A, stride_xz, LUT, #
|
||||
R, extent, stride_zr, stride_hr, # relative attention
|
||||
scale, is_causal, #
|
||||
ROW_SIZE: tl.constexpr, #
|
||||
BLOCK_SIZE: tl.constexpr, #
|
||||
IS_DENSE: tl.constexpr #
|
||||
):
|
||||
h = tl.program_id(0)
|
||||
m = tl.program_id(1)
|
||||
z = tl.program_id(2)
|
||||
@@ -73,18 +72,16 @@ def _blocksparse_softmax_fwd(
|
||||
|
||||
|
||||
@jit
|
||||
def _blocksparse_softmax_bwd(
|
||||
DA, stride_zdx,
|
||||
DOut, stride_zdout,
|
||||
Out, stride_zout,
|
||||
scale,
|
||||
LUT,
|
||||
DR, extent, stride_zr, stride_hr, stride_er,
|
||||
is_causal,
|
||||
ROW_SIZE: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
IS_DENSE: tl.constexpr,
|
||||
):
|
||||
def _blocksparse_softmax_bwd(DA, stride_zdx, #
|
||||
DOut, stride_zdout, #
|
||||
Out, stride_zout, #
|
||||
scale, #
|
||||
LUT, #
|
||||
DR, extent, stride_zr, stride_hr, stride_er, #
|
||||
is_causal, #
|
||||
ROW_SIZE: tl.constexpr, #
|
||||
BLOCK_SIZE: tl.constexpr, #
|
||||
IS_DENSE: tl.constexpr):
|
||||
h = tl.program_id(0)
|
||||
m = tl.program_id(1)
|
||||
z = tl.program_id(2)
|
||||
@@ -133,6 +130,7 @@ def _blocksparse_softmax_bwd(
|
||||
|
||||
|
||||
class _softmax(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def make_lut(layout, block, device):
|
||||
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
|
||||
@@ -151,10 +149,7 @@ class _softmax(torch.autograd.Function):
|
||||
return lut, int(total_sizes.max())
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx, a, scale, rel_logits, is_causal,
|
||||
spdims, block, lut, maxlut, is_dense
|
||||
):
|
||||
def forward(ctx, a, scale, rel_logits, is_causal, spdims, block, lut, maxlut, is_dense):
|
||||
if scale is not None and isinstance(scale, torch.Tensor):
|
||||
assert scale.device.type == "cpu"
|
||||
scale = scale.item()
|
||||
@@ -165,14 +160,14 @@ class _softmax(torch.autograd.Function):
|
||||
# enqueue kernel
|
||||
out = torch.empty_like(a)
|
||||
_blocksparse_softmax_fwd[grid](
|
||||
out, a, a.stride(0), lut,
|
||||
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
|
||||
scale,
|
||||
is_causal,
|
||||
BLOCK_SIZE=block,
|
||||
ROW_SIZE=next_power_of_2(maxlut),
|
||||
IS_DENSE=is_dense,
|
||||
num_warps=num_warps(maxlut)
|
||||
out, a, a.stride(0), lut, #
|
||||
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn#
|
||||
scale, #
|
||||
is_causal, #
|
||||
BLOCK_SIZE=block, #
|
||||
ROW_SIZE=next_power_of_2(maxlut), #
|
||||
IS_DENSE=is_dense, #
|
||||
num_warps=num_warps(maxlut) #
|
||||
)
|
||||
# save to context
|
||||
# ctx.mark_dirty(x)
|
||||
@@ -201,28 +196,23 @@ class _softmax(torch.autograd.Function):
|
||||
grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)
|
||||
da = torch.empty_like(dout)
|
||||
_blocksparse_softmax_bwd[grid](
|
||||
da, da.stride(0),
|
||||
dout, dout.stride(0),
|
||||
out, out.stride(0),
|
||||
ctx.scale,
|
||||
lut,
|
||||
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],
|
||||
ctx.is_causal,
|
||||
BLOCK_SIZE=ctx.block,
|
||||
ROW_SIZE=next_power_of_2(ctx.maxlut),
|
||||
IS_DENSE=ctx.is_dense,
|
||||
num_warps=num_warps(ctx.maxlut)
|
||||
da, da.stride(0), #
|
||||
dout, dout.stride(0), #
|
||||
out, out.stride(0), #
|
||||
ctx.scale, #
|
||||
lut, #
|
||||
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2], #
|
||||
ctx.is_causal, #
|
||||
BLOCK_SIZE=ctx.block, #
|
||||
ROW_SIZE=next_power_of_2(ctx.maxlut), #
|
||||
IS_DENSE=ctx.is_dense, #
|
||||
num_warps=num_warps(ctx.maxlut) #
|
||||
)
|
||||
return (da, None, None, dr, None,
|
||||
None, None, None, None, None,
|
||||
None,
|
||||
None, None, None,
|
||||
None,
|
||||
None, None, None
|
||||
)
|
||||
return (da, None, None, dr, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
|
||||
|
||||
|
||||
class softmax:
|
||||
|
||||
def __init__(self, layout, block, device, is_dense=False):
|
||||
self.spdims = layout.shape
|
||||
self.layout = layout
|
||||
@@ -233,8 +223,6 @@ class softmax:
|
||||
def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False):
|
||||
if rel_logits is not None and rel_logits.dtype != a.dtype:
|
||||
raise ValueError(f"relative position embedding must be {a.dtype}")
|
||||
a = _softmax.apply(
|
||||
a, scale, rel_logits, is_causal,
|
||||
self.spdims, self.block, self.lut, self.maxlut, self.is_dense,
|
||||
)
|
||||
a = _softmax.apply(a, scale, rel_logits, is_causal, self.spdims, self.block, self.lut, self.maxlut,
|
||||
self.is_dense)
|
||||
return a
|
||||
|
||||
@@ -59,6 +59,7 @@ def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
|
||||
|
||||
|
||||
class _cross_entropy(torch.autograd.Function):
|
||||
|
||||
@classmethod
|
||||
def forward(cls, ctx, logits, indices):
|
||||
# make sure we can use triton
|
||||
|
||||
@@ -15,20 +15,19 @@ from .. import language as tl
|
||||
|
||||
|
||||
@jit
|
||||
def _fwd_kernel(
|
||||
Q, K, V, sm_scale,
|
||||
L,
|
||||
Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
stride_oz, stride_oh, stride_om, stride_on,
|
||||
Z, H, N_CTX,
|
||||
Z_H_N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
):
|
||||
def _fwd_kernel(Q, K, V, sm_scale, #
|
||||
L, #
|
||||
Out, #
|
||||
stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vn, stride_vk, #
|
||||
stride_oz, stride_oh, stride_om, stride_on, #
|
||||
Z, H, N_CTX, #
|
||||
Z_H_N_CTX, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
|
||||
BLOCK_N: tl.constexpr, #
|
||||
IS_CAUSAL: tl.constexpr #
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
qvk_offset = off_hz * stride_qh
|
||||
@@ -40,7 +39,7 @@ def _fwd_kernel(
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, vk_offset),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||
order=(0, 1)
|
||||
order=(0, 1),
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V,
|
||||
@@ -48,7 +47,7 @@ def _fwd_kernel(
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(vk_offset, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
order=(1, 0),
|
||||
)
|
||||
# initialize offsets
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
@@ -104,7 +103,7 @@ def _fwd_kernel(
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(vk_offset + start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
order=(1, 0),
|
||||
)
|
||||
# O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
|
||||
tl.store(O_block_ptr, acc.to(K.dtype.element_ty))
|
||||
@@ -112,9 +111,11 @@ def _fwd_kernel(
|
||||
|
||||
@jit
|
||||
def _bwd_preprocess(
|
||||
Out, DO,
|
||||
Out,
|
||||
DO,
|
||||
Delta,
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
D_HEAD: tl.constexpr,
|
||||
):
|
||||
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, D_HEAD)
|
||||
@@ -128,40 +129,48 @@ def _bwd_preprocess(
|
||||
|
||||
|
||||
@jit
|
||||
def _bwd_kernel_one_col_block(
|
||||
Q, K, V, sm_scale, qk_scale,
|
||||
Out, DO,
|
||||
DQ, DK, DV,
|
||||
L,
|
||||
D,
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
Z, H, N_CTX,
|
||||
off_hz, start_n, num_block,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
SEQUENCE_PARALLEL: tl.constexpr,
|
||||
CAUSAL: tl.constexpr,
|
||||
MMA_V3: tl.constexpr
|
||||
):
|
||||
if SEQUENCE_PARALLEL:
|
||||
DQ += stride_dqa.to(tl.int64) * start_n
|
||||
def _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, #
|
||||
Out, DO, #
|
||||
DQ, DK, DV, #
|
||||
L, #
|
||||
D, #
|
||||
Q_block_ptr, K_block_ptr, V_block_ptr, #
|
||||
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vn, stride_vk, #
|
||||
Z, H, N_CTX, #
|
||||
off_h, off_z, off_hz, start_n, num_block, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
|
||||
BLOCK_N: tl.constexpr, #
|
||||
SEQUENCE_PARALLEL: tl.constexpr, #
|
||||
CAUSAL: tl.constexpr, #
|
||||
MMA_V3: tl.constexpr #
|
||||
):
|
||||
if CAUSAL:
|
||||
lo = start_n * BLOCK_M
|
||||
else:
|
||||
lo = 0
|
||||
|
||||
Q_offset = (off_z * stride_qz + off_h * stride_qh) // stride_qm
|
||||
DQ_offset = off_z * stride_qz + off_h * stride_qh
|
||||
K_offset = (off_z * stride_kz + off_h * stride_kh) // stride_kn
|
||||
V_offset = (off_z * stride_vz + off_h * stride_vh) // stride_vn
|
||||
if SEQUENCE_PARALLEL:
|
||||
DQ_offset += stride_dqa.to(tl.int64) * start_n
|
||||
DQ_offset = DQ_offset // stride_qm
|
||||
|
||||
Q_block_ptr = tl.advance(Q_block_ptr, (lo + Q_offset, 0))
|
||||
K_block_ptr = tl.advance(K_block_ptr, (start_n * BLOCK_M + K_offset, 0))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (start_n * BLOCK_M + V_offset, 0))
|
||||
DO_block_ptr = tl.advance(DO_block_ptr, (lo + Q_offset, 0))
|
||||
DQ_block_ptr = tl.advance(DQ_block_ptr, (lo + DQ_offset, 0))
|
||||
DK_block_ptr = tl.advance(DK_block_ptr, (start_n * BLOCK_M + K_offset, 0))
|
||||
DV_block_ptr = tl.advance(DV_block_ptr, (start_n * BLOCK_M + V_offset, 0))
|
||||
|
||||
# initialize row/col offsets
|
||||
offs_qm = lo + tl.arange(0, BLOCK_M)
|
||||
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_m = tl.arange(0, BLOCK_N)
|
||||
offs_k = tl.arange(0, BLOCK_DMODEL)
|
||||
# initialize pointers to value-like data
|
||||
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
|
||||
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk)
|
||||
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
# pointer to row-wise quantities in value-like data
|
||||
D_ptrs = D + off_hz * N_CTX
|
||||
l_ptrs = L + off_hz * N_CTX
|
||||
@@ -169,17 +178,17 @@ def _bwd_kernel_one_col_block(
|
||||
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# k and v stay in SRAM throughout
|
||||
k = tl.load(k_ptrs)
|
||||
v = tl.load(v_ptrs)
|
||||
k = tl.load(K_block_ptr)
|
||||
v = tl.load(V_block_ptr)
|
||||
# loop over rows
|
||||
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
|
||||
offs_m_curr = start_m + offs_m
|
||||
# load q, k, v, do on-chip
|
||||
q = tl.load(q_ptrs)
|
||||
q = tl.load(Q_block_ptr)
|
||||
# recompute p = softmax(qk, dim=-1).T
|
||||
# NOTE: `do` is pre-divided by `l`; no normalization here
|
||||
if CAUSAL:
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.0), float("-inf"))
|
||||
else:
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, tl.trans(k))
|
||||
@@ -187,7 +196,7 @@ def _bwd_kernel_one_col_block(
|
||||
l_i = tl.load(l_ptrs + offs_m_curr)
|
||||
p = tl.math.exp2(qk - l_i[:, None])
|
||||
# compute dv
|
||||
do = tl.load(do_ptrs)
|
||||
do = tl.load(DO_block_ptr)
|
||||
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do, allow_tf32=True)
|
||||
# compute dp = dot(v, do)
|
||||
Di = tl.load(D_ptrs + offs_m_curr)
|
||||
@@ -199,97 +208,156 @@ def _bwd_kernel_one_col_block(
|
||||
dk += tl.dot(tl.trans(ds), q, allow_tf32=True)
|
||||
# compute dq
|
||||
if not SEQUENCE_PARALLEL:
|
||||
dq = tl.load(dq_ptrs)
|
||||
dq = tl.load(DQ_block_ptr)
|
||||
dq += tl.dot(ds, k, allow_tf32=True)
|
||||
tl.store(dq_ptrs, dq)
|
||||
tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty))
|
||||
elif SEQUENCE_PARALLEL:
|
||||
if MMA_V3:
|
||||
dq = tl.dot(ds, k, allow_tf32=True)
|
||||
else:
|
||||
# not work with mma v3, becuase M % 64 != 0
|
||||
dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds), allow_tf32=True))
|
||||
tl.store(dq_ptrs, dq)
|
||||
tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty))
|
||||
|
||||
# increment pointers
|
||||
dq_ptrs += BLOCK_M * stride_qm
|
||||
q_ptrs += BLOCK_M * stride_qm
|
||||
do_ptrs += BLOCK_M * stride_qm
|
||||
DQ_block_ptr = tl.advance(DQ_block_ptr, (BLOCK_M, 0))
|
||||
Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0))
|
||||
DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0))
|
||||
# write-back
|
||||
dv_ptrs = DV + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk)
|
||||
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
|
||||
tl.store(dv_ptrs, dv)
|
||||
tl.store(dk_ptrs, dk)
|
||||
tl.store(DV_block_ptr, dv.to(V.dtype.element_ty))
|
||||
tl.store(DK_block_ptr, dk.to(K.dtype.element_ty))
|
||||
|
||||
|
||||
@jit
|
||||
def _bwd_kernel(
|
||||
# fmt: off
|
||||
Q, K, V, sm_scale,
|
||||
Out, DO,
|
||||
DQ, DK, DV,
|
||||
L,
|
||||
D,
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
Z, H, N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
SEQUENCE_PARALLEL: tl.constexpr,
|
||||
CAUSAL: tl.constexpr,
|
||||
MMA_V3: tl.constexpr
|
||||
# fmt: on
|
||||
):
|
||||
def _bwd_kernel(Q, K, V, sm_scale, #
|
||||
Out, DO, #
|
||||
DQ, DK, DV, #
|
||||
L, #
|
||||
D, #
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vn, stride_vk, #
|
||||
Z, H, N_CTX, #
|
||||
Z_H_N_CTX, #
|
||||
SQ_Z_H_N_CTX, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
|
||||
BLOCK_N: tl.constexpr, #
|
||||
SEQUENCE_PARALLEL: tl.constexpr, #
|
||||
CAUSAL: tl.constexpr, #
|
||||
MMA_V3: tl.constexpr #
|
||||
):
|
||||
qk_scale = sm_scale * 1.44269504
|
||||
off_hz = tl.program_id(0)
|
||||
off_z = off_hz // H
|
||||
off_h = off_hz % H
|
||||
# offset pointers for batch/head
|
||||
Q += off_z * stride_qz + off_h * stride_qh
|
||||
K += off_z * stride_kz + off_h * stride_kh
|
||||
V += off_z * stride_vz + off_h * stride_vh
|
||||
DO += off_z * stride_qz + off_h * stride_qh
|
||||
DQ += off_z * stride_qz + off_h * stride_qh
|
||||
DK += off_z * stride_kz + off_h * stride_kh
|
||||
DV += off_z * stride_vz + off_h * stride_vh
|
||||
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_kn, stride_kk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
DO_block_ptr = tl.make_block_ptr(
|
||||
base=DO,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
if SEQUENCE_PARALLEL:
|
||||
DQ_block_ptr = tl.make_block_ptr(
|
||||
base=DQ,
|
||||
shape=(SQ_Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
else:
|
||||
DQ_block_ptr = tl.make_block_ptr(
|
||||
base=DQ,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
|
||||
DK_block_ptr = tl.make_block_ptr(
|
||||
base=DK,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_kn, stride_kk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
DV_block_ptr = tl.make_block_ptr(
|
||||
base=DV,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
|
||||
num_block_n = tl.cdiv(N_CTX, BLOCK_N)
|
||||
if not SEQUENCE_PARALLEL:
|
||||
for start_n in range(0, num_block_n):
|
||||
_bwd_kernel_one_col_block(
|
||||
Q, K, V, sm_scale, qk_scale, Out, DO,
|
||||
DQ, DK, DV,
|
||||
L,
|
||||
D,
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
Z, H, N_CTX,
|
||||
off_hz, start_n, num_block_n,
|
||||
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_N=BLOCK_N,
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
|
||||
CAUSAL=CAUSAL,
|
||||
MMA_V3=MMA_V3
|
||||
)
|
||||
_bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, #
|
||||
DQ, DK, DV, #
|
||||
L, #
|
||||
D, #
|
||||
Q_block_ptr, K_block_ptr, V_block_ptr, #
|
||||
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vn, stride_vk, #
|
||||
Z, H, N_CTX, #
|
||||
off_h, off_z, off_hz, start_n, num_block_n, #
|
||||
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, #
|
||||
BLOCK_N=BLOCK_N, #
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, #
|
||||
CAUSAL=CAUSAL, #
|
||||
MMA_V3=MMA_V3 #
|
||||
)
|
||||
else:
|
||||
start_n = tl.program_id(1)
|
||||
_bwd_kernel_one_col_block(
|
||||
Q, K, V, sm_scale, qk_scale, Out, DO,
|
||||
DQ, DK, DV,
|
||||
L,
|
||||
D,
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
Z, H, N_CTX,
|
||||
off_hz, start_n, num_block_n,
|
||||
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_N=BLOCK_N,
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
|
||||
CAUSAL=CAUSAL,
|
||||
MMA_V3=MMA_V3
|
||||
)
|
||||
_bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, #
|
||||
DQ, DK, DV, #
|
||||
L, #
|
||||
D, #
|
||||
Q_block_ptr, K_block_ptr, V_block_ptr, #
|
||||
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vn, stride_vk, #
|
||||
Z, H, N_CTX, #
|
||||
off_h, off_z, off_hz, start_n, num_block_n, #
|
||||
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, #
|
||||
BLOCK_N=BLOCK_N, #
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, #
|
||||
CAUSAL=CAUSAL, #
|
||||
MMA_V3=MMA_V3 #
|
||||
)
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
@@ -315,19 +383,20 @@ class _attention(torch.autograd.Function):
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
_fwd_kernel[grid](
|
||||
q, k, v, sm_scale,
|
||||
L,
|
||||
o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
q.shape[0] * q.shape[1] * q.shape[2],
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
|
||||
IS_CAUSAL=causal,
|
||||
num_warps=num_warps,
|
||||
num_stages=4)
|
||||
q, k, v, sm_scale, #
|
||||
L, #
|
||||
o, #
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
|
||||
q.shape[0], q.shape[1], q.shape[2], #
|
||||
q.shape[0] * q.shape[1] * q.shape[2], #
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, #
|
||||
IS_CAUSAL=causal, #
|
||||
num_warps=num_warps, #
|
||||
num_stages=4 #
|
||||
)
|
||||
|
||||
ctx.save_for_backward(q, k, v, o, L)
|
||||
ctx.grid = grid
|
||||
@@ -348,35 +417,39 @@ class _attention(torch.autograd.Function):
|
||||
do = do.contiguous()
|
||||
if sequence_parallel:
|
||||
replicas = cdiv(seq_len_kv, BLOCK)
|
||||
new_dq_shape = (replicas,) + q.shape
|
||||
new_dq_shape = (replicas, ) + q.shape
|
||||
dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype)
|
||||
else:
|
||||
dq = torch.zeros_like(q, dtype=torch.float32)
|
||||
dq = torch.zeros_like(q, dtype=q.dtype)
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
delta = torch.empty_like(L)
|
||||
_bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1], )](
|
||||
o, do,
|
||||
o,
|
||||
do,
|
||||
delta,
|
||||
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
BLOCK_M=BLOCK,
|
||||
D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
_bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do,
|
||||
dq, dk, dv,
|
||||
L,
|
||||
delta,
|
||||
o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
|
||||
SEQUENCE_PARALLEL=sequence_parallel,
|
||||
CAUSAL=ctx.causal,
|
||||
MMA_V3=MMA_V3,
|
||||
num_warps=8,
|
||||
num_stages=1,
|
||||
q, k, v, ctx.sm_scale, #
|
||||
o, do, #
|
||||
dq, dk, dv, #
|
||||
L, #
|
||||
delta, #
|
||||
o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
|
||||
q.shape[0], q.shape[1], q.shape[2], #
|
||||
q.shape[0] * q.shape[1] * q.shape[2], #
|
||||
cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2], #
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK, #
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, #
|
||||
SEQUENCE_PARALLEL=sequence_parallel, #
|
||||
CAUSAL=ctx.causal, #
|
||||
MMA_V3=MMA_V3, #
|
||||
num_warps=8, #
|
||||
num_stages=1 #
|
||||
)
|
||||
|
||||
if len(dq.shape) == 5:
|
||||
|
||||
@@ -37,8 +37,9 @@ def get_configs_io_bound():
|
||||
num_stages=num_stages, num_warps=num_warps))
|
||||
# split_k
|
||||
for split_k in [2, 4, 8, 16]:
|
||||
configs.append(Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||
configs.append(
|
||||
Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||
return configs
|
||||
|
||||
|
||||
@@ -69,22 +70,22 @@ def get_configs_io_bound():
|
||||
prune_configs_by={
|
||||
'early_config_prune': early_config_prune,
|
||||
'perf_model': estimate_matmul_time,
|
||||
'top_k': 10
|
||||
'top_k': 10,
|
||||
},
|
||||
)
|
||||
@heuristics({
|
||||
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||
})
|
||||
@jit
|
||||
def _kernel(A, B, C, M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
dot_out_dtype: tl.constexpr,
|
||||
allow_tf32: tl.constexpr,
|
||||
fp8_fast_accum: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr
|
||||
def _kernel(A, B, C, M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
dot_out_dtype: tl.constexpr, #
|
||||
allow_tf32: tl.constexpr, #
|
||||
fp8_fast_accum: tl.constexpr, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr #
|
||||
):
|
||||
# matrix multiplication
|
||||
pid = tl.program_id(0)
|
||||
@@ -184,14 +185,15 @@ class _matmul(torch.autograd.Function):
|
||||
ab_dtype = False
|
||||
# launch kernel
|
||||
grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||
_kernel[grid](a, b, c, M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
dot_out_dtype=dot_out_dtype,
|
||||
allow_tf32=allow_tf32,
|
||||
fp8_fast_accum=fp8_fast_accum,
|
||||
GROUP_M=8, AB_DTYPE=ab_dtype)
|
||||
_kernel[grid](
|
||||
a, b, c, M, N, K, #
|
||||
a.stride(0), a.stride(1), #
|
||||
b.stride(0), b.stride(1), #
|
||||
c.stride(0), c.stride(1), #
|
||||
dot_out_dtype=dot_out_dtype, #
|
||||
allow_tf32=allow_tf32, #
|
||||
fp8_fast_accum=fp8_fast_accum, #
|
||||
GROUP_M=8, AB_DTYPE=ab_dtype)
|
||||
return c
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -5,8 +5,7 @@ import torch
|
||||
from .. import cdiv
|
||||
from .._C.libtriton.triton import runtime
|
||||
from ..runtime import driver
|
||||
from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops,
|
||||
nvsmi)
|
||||
from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops, nvsmi)
|
||||
|
||||
|
||||
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
@@ -14,7 +13,8 @@ def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
total_warps = num_ctas * min(num_warps, 4)
|
||||
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
|
||||
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
|
||||
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, cur_sm_clock, backend, device)
|
||||
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(
|
||||
dtype, cur_sm_clock, backend, device)
|
||||
return tflops
|
||||
|
||||
|
||||
@@ -35,12 +35,12 @@ def get_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
|
||||
|
||||
def estimate_matmul_time(
|
||||
# backend, device,
|
||||
num_warps, num_stages,
|
||||
A, B, C,
|
||||
M, N, K,
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
|
||||
debug=False, **kwargs
|
||||
# backend, device,
|
||||
num_warps, num_stages, #
|
||||
A, B, C, #
|
||||
M, N, K, #
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, #
|
||||
debug=False, **kwargs #
|
||||
):
|
||||
''' return estimated running time in ms
|
||||
= max(compute, loading) + store '''
|
||||
@@ -149,8 +149,9 @@ def early_config_prune(configs, named_args):
|
||||
optimal_num_stages = ldgsts_latency / mma_cycles
|
||||
|
||||
# nearest stages, prefer large #stages
|
||||
nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
|
||||
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
|
||||
nearest = heapq.nsmallest(
|
||||
2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
|
||||
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
|
||||
|
||||
for n in nearest:
|
||||
pruned_configs.append(n[0])
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune,
|
||||
heuristics)
|
||||
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune, heuristics)
|
||||
from .driver import driver
|
||||
from .jit import (JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret,
|
||||
version_key)
|
||||
from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret
|
||||
|
||||
__all__ = [
|
||||
"driver",
|
||||
@@ -12,7 +10,6 @@ __all__ = [
|
||||
"heuristics",
|
||||
"JITFunction",
|
||||
"KernelInterface",
|
||||
"version_key",
|
||||
"reinterpret",
|
||||
"TensorWrapper",
|
||||
"OutOfResources",
|
||||
|
||||
@@ -9,11 +9,10 @@ from .jit import KernelInterface
|
||||
|
||||
|
||||
class OutOfResources(Exception):
|
||||
|
||||
def __init__(self, required, limit, name):
|
||||
self.message = f'out of resource: {name}, '\
|
||||
f'Required: {required}, '\
|
||||
f'Hardware limit: {limit}'
|
||||
self.message += '. Reducing block sizes or `num_stages` may help.'
|
||||
self.message = (f"out of resource: {name}, Required: {required}, Hardware limit: {limit}. " +
|
||||
"Reducing block sizes or `num_stages` may help.")
|
||||
self.required = required
|
||||
self.limit = limit
|
||||
self.name = name
|
||||
@@ -25,38 +24,73 @@ class OutOfResources(Exception):
|
||||
|
||||
|
||||
class Autotuner(KernelInterface):
|
||||
def __init__(self, fn, arg_names, configs, key, verbose, reset_to_zero, prune_configs_by: Dict = None, warmup=25, rep=100):
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fn,
|
||||
arg_names,
|
||||
configs,
|
||||
key,
|
||||
verbose,
|
||||
reset_to_zero,
|
||||
restore_value,
|
||||
prune_configs_by: Dict = None,
|
||||
warmup=25,
|
||||
rep=100,
|
||||
):
|
||||
"""
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predicate running time with different configs, returns running time
|
||||
'top_k': number of configs to bench
|
||||
'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
|
||||
'''
|
||||
"""
|
||||
if not configs:
|
||||
self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)]
|
||||
else:
|
||||
self.configs = configs
|
||||
self.key_idx = [arg_names.index(k) for k in key]
|
||||
self.cache = {}
|
||||
# hook to reset all required tensor to zeros before relaunching a kernel
|
||||
self.hook = lambda args: 0
|
||||
self.arg_names = arg_names
|
||||
|
||||
# Reset to zero or restore values
|
||||
self.reset_idx = []
|
||||
if reset_to_zero is not None:
|
||||
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
|
||||
self.restore_idx = []
|
||||
if restore_value is not None:
|
||||
self.restore_idx = [arg_names.index(k) for k in restore_value]
|
||||
|
||||
def _hook(args):
|
||||
# Hook to reset or restore for required tensors
|
||||
self.pre_hook = lambda args, reset_only=False: 0
|
||||
self.post_hook = lambda args: 0
|
||||
if len(self.reset_idx) > 0 or len(self.restore_idx) > 0:
|
||||
|
||||
def _pre_hook(args, reset_only=False):
|
||||
for i in self.reset_idx:
|
||||
args[i].zero_()
|
||||
self.hook = _hook
|
||||
self.arg_names = arg_names
|
||||
# prune configs
|
||||
if not reset_only:
|
||||
self.restore_copies = [args[i].clone() for i in self.restore_idx]
|
||||
|
||||
self.pre_hook = _pre_hook
|
||||
if len(self.restore_idx) > 0:
|
||||
|
||||
def _post_hook(args):
|
||||
for i, j in enumerate(self.restore_idx):
|
||||
args[j].copy_(self.restore_copies[i])
|
||||
self.restore_copies = []
|
||||
|
||||
self.post_hook = _post_hook
|
||||
|
||||
# Prune configs
|
||||
if prune_configs_by:
|
||||
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
|
||||
if 'early_config_prune' in prune_configs_by:
|
||||
early_config_prune = prune_configs_by['early_config_prune']
|
||||
perf_model, top_k = prune_configs_by["perf_model"], prune_configs_by["top_k"]
|
||||
if "early_config_prune" in prune_configs_by:
|
||||
early_config_prune = prune_configs_by["early_config_prune"]
|
||||
else:
|
||||
perf_model, top_k, early_config_prune = None, None, None
|
||||
self.perf_model, self.configs_top_k = perf_model, top_k
|
||||
self.early_config_prune = early_config_prune
|
||||
|
||||
self.fn = fn
|
||||
self.warmup = warmup
|
||||
self.rep = rep
|
||||
@@ -67,10 +101,8 @@ class Autotuner(KernelInterface):
|
||||
# as kwargs and by the autotuner
|
||||
conflicts = meta.keys() & config.kwargs.keys()
|
||||
if conflicts:
|
||||
raise ValueError(
|
||||
f"Conflicting meta-parameters: {', '.join(conflicts)}."
|
||||
" Make sure that you don't re-define auto-tuned symbols."
|
||||
)
|
||||
raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}."
|
||||
" Make sure that you don't re-define auto-tuned symbols.")
|
||||
# augment meta-parameters with tunable ones
|
||||
current = dict(meta, **config.kwargs)
|
||||
full_nargs = {**self.nargs, **current}
|
||||
@@ -78,16 +110,22 @@ class Autotuner(KernelInterface):
|
||||
def kernel_call():
|
||||
if config.pre_hook:
|
||||
config.pre_hook(full_nargs)
|
||||
self.hook(args)
|
||||
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
# enable_persistent=False,
|
||||
**current)
|
||||
self.pre_hook(args)
|
||||
self.fn.run(
|
||||
*args,
|
||||
num_warps=config.num_warps,
|
||||
num_stages=config.num_stages,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
# enable_persistent=False,
|
||||
**current,
|
||||
)
|
||||
self.post_hook(args)
|
||||
|
||||
try:
|
||||
return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
|
||||
except OutOfResources:
|
||||
return [float('inf'), float('inf'), float('inf')]
|
||||
return [float("inf"), float("inf"), float("inf")]
|
||||
|
||||
def get_best_config(self):
|
||||
return self.best_config
|
||||
@@ -110,12 +148,11 @@ class Autotuner(KernelInterface):
|
||||
# prune configs
|
||||
pruned_configs = self.prune_configs(kwargs)
|
||||
bench_start = time.time()
|
||||
timings = {config: self._bench(*args, config=config, **kwargs)
|
||||
for config in pruned_configs}
|
||||
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
|
||||
bench_end = time.time()
|
||||
self.bench_time = bench_end - bench_start
|
||||
self.cache[key] = builtins.min(timings, key=timings.get)
|
||||
self.hook(args)
|
||||
self.pre_hook(args, reset_only=True)
|
||||
self.configs_timings = timings
|
||||
if self.verbose:
|
||||
print(str(key) + ": " + str(self.cache[key]))
|
||||
@@ -126,9 +163,15 @@ class Autotuner(KernelInterface):
|
||||
full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs}
|
||||
if config.pre_hook is not None:
|
||||
config.pre_hook(full_nargs)
|
||||
ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization, **kwargs, **config.kwargs)
|
||||
ret = self.fn.run(
|
||||
*args,
|
||||
num_warps=config.num_warps,
|
||||
num_stages=config.num_stages,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
**kwargs,
|
||||
**config.kwargs,
|
||||
)
|
||||
self.nargs = None
|
||||
return ret
|
||||
|
||||
@@ -142,17 +185,20 @@ class Autotuner(KernelInterface):
|
||||
top_k = int(len(self.configs) * top_k)
|
||||
if len(pruned_configs) > top_k:
|
||||
est_timing = {
|
||||
config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
|
||||
num_warps=config.num_warps,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
enable_persistent=config.enable_persistent)
|
||||
config:
|
||||
self.perf_model(
|
||||
**self.nargs,
|
||||
**kwargs,
|
||||
**config.kwargs,
|
||||
num_stages=config.num_stages,
|
||||
num_warps=config.num_warps,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
enable_persistent=config.enable_persistent,
|
||||
)
|
||||
for config in pruned_configs
|
||||
}
|
||||
pruned_configs = sorted(
|
||||
est_timing.keys(),
|
||||
key=lambda x: est_timing[x])[
|
||||
:top_k]
|
||||
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
|
||||
return pruned_configs
|
||||
|
||||
def warmup(self, *args, **kwargs):
|
||||
@@ -195,7 +241,7 @@ class Config:
|
||||
self.num_ctas = num_ctas
|
||||
self.num_stages = num_stages
|
||||
self.enable_warp_specialization = enable_warp_specialization
|
||||
# TODO[shuhaoj]: May make enable_persistent configurable in future if necessay.
|
||||
# TODO[shuhaoj]: May make enable_persistent configurable in future if necessary.
|
||||
self.enable_persistent = False
|
||||
self.pre_hook = pre_hook
|
||||
|
||||
@@ -207,13 +253,12 @@ class Config:
|
||||
## Comment out Hopper specific parameters
|
||||
#res.append(f'num_ctas: {self.num_ctas}')
|
||||
res.append(f'num_stages: {self.num_stages}')
|
||||
#res.append(
|
||||
# f'enable_warp_specialization: {self.enable_warp_specialization}')
|
||||
#res.append(f'enable_warp_specialization: {self.enable_warp_specialization}')
|
||||
#res.append(f'enable_persistent: {self.enable_persistent}')
|
||||
return ', '.join(res)
|
||||
|
||||
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=False, warmup=25, rep=100):
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, verbose=False, warmup=25, rep=100):
|
||||
"""
|
||||
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
||||
|
||||
@@ -244,6 +289,8 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=Fa
|
||||
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
|
||||
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
|
||||
:type reset_to_zero: list[str]
|
||||
:param restore_value: a list of argument names whose value will be restored after evaluating any configs.
|
||||
:type restore_value: list[str]
|
||||
:param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25.
|
||||
:type warmup: int
|
||||
:param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
|
||||
@@ -251,8 +298,9 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=Fa
|
||||
:param verbose: a boolean that controls whether the best_config for each key is printed
|
||||
:type verbose: bool
|
||||
"""
|
||||
|
||||
def decorator(fn):
|
||||
return Autotuner(fn, fn.arg_names, configs, key, verbose, reset_to_zero, prune_configs_by, warmup, rep)
|
||||
return Autotuner(fn, fn.arg_names, configs, key, verbose, reset_to_zero, restore_value, prune_configs_by, warmup, rep)
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -286,6 +334,7 @@ def heuristics(values):
|
||||
each such function takes a list of positional arguments as input.
|
||||
:type values: dict[str, Callable[[list[Any]], Any]]
|
||||
"""
|
||||
|
||||
def decorator(fn):
|
||||
return Heuristics(fn, fn.arg_names, values)
|
||||
|
||||
|
||||
@@ -1,27 +1,42 @@
|
||||
#include "cuda.h"
|
||||
#include <dlfcn.h>
|
||||
#include <stdbool.h>
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
|
||||
// Raises a Python exception and returns false if code is not CUDA_SUCCESS.
//
// `file` and `line` are accepted because the CUDA_CHECK_* macros pass
// __FILE__/__LINE__, but they are not included in the error message.
static bool gpuAssert(CUresult code, const char *file, int line) {
  if (code == CUDA_SUCCESS)
    return true;

  const char *prefix = "Triton Error [CUDA]: ";
  // cuGetErrorString() sets the pointer to NULL for error codes it does not
  // recognize; handing NULL to strcat()/strncat() is undefined behavior.
  const char *str = NULL;
  if (cuGetErrorString(code, &str) != CUDA_SUCCESS || str == NULL)
    str = "unknown error";

  char err[1024] = {0};
  strcat(err, prefix);
  // Bound the copy so a long driver message cannot overflow `err`.
  strncat(err, str, sizeof(err) - strlen(err) - 1);

  // This may run inside a Py_BEGIN_ALLOW_THREADS block, so acquire the GIL
  // before touching the Python error indicator.
  PyGILState_STATE gil_state;
  gil_state = PyGILState_Ensure();
  PyErr_SetString(PyExc_RuntimeError, err);
  PyGILState_Release(gil_state);
  return false;
}
|
||||
|
||||
#define CUDA_CHECK(ans) \
|
||||
{ \
|
||||
{ gpuAssert((ans), __FILE__, __LINE__); } \
|
||||
}
|
||||
// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block.
|
||||
#define CUDA_CHECK_AND_RETURN_NULL(ans) \
|
||||
do { \
|
||||
if (!gpuAssert((ans), __FILE__, __LINE__)) \
|
||||
return NULL; \
|
||||
} while (0)
|
||||
|
||||
// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block.
|
||||
#define CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \
|
||||
do { \
|
||||
if (!gpuAssert((ans), __FILE__, __LINE__)) { \
|
||||
PyEval_RestoreThread(_save); \
|
||||
return NULL; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define ADD_ENUM_ITEM(value) \
|
||||
do { \
|
||||
@@ -200,16 +215,16 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
|
||||
int sm_clock_rate;
|
||||
int mem_clock_rate;
|
||||
int mem_bus_width;
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
|
||||
&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
|
||||
device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
|
||||
&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate,
|
||||
CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
|
||||
&sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
|
||||
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
|
||||
&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
|
||||
&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
|
||||
|
||||
return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
|
||||
@@ -237,33 +252,37 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
|
||||
CUcontext pctx = 0;
|
||||
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
CUDA_CHECK(cuCtxGetCurrent(&pctx));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&pctx));
|
||||
if (!pctx) {
|
||||
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
|
||||
CUDA_CHECK(cuCtxSetCurrent(pctx));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
||||
cuDevicePrimaryCtxRetain(&pctx, device));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(pctx));
|
||||
}
|
||||
|
||||
CUDA_CHECK(cuModuleLoadData(&mod, data));
|
||||
CUDA_CHECK(cuModuleGetFunction(&fun, mod, name));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuModuleLoadData(&mod, data));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
||||
cuModuleGetFunction(&fun, mod, name));
|
||||
// get allocated registers and spilled registers from the function
|
||||
CUDA_CHECK(cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
|
||||
CUDA_CHECK(
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
||||
cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
||||
cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
|
||||
n_spills /= 4;
|
||||
// set dynamic shared memory if necessary
|
||||
int shared_optin;
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
|
||||
&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
|
||||
device));
|
||||
if (shared > 49152 && shared_optin > 49152) {
|
||||
CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
||||
cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
|
||||
int shared_total, shared_static;
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
|
||||
&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
|
||||
device));
|
||||
CUDA_CHECK(cuFuncGetAttribute(&shared_static,
|
||||
CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
|
||||
CUDA_CHECK(
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute(
|
||||
&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
||||
cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
||||
shared_optin - shared_static));
|
||||
}
|
||||
@@ -286,7 +305,7 @@ static PyObject *memAlloc(PyObject *self, PyObject *args) {
|
||||
}
|
||||
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
CUDA_CHECK(cuMemAlloc(&dptr, bytesize));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuMemAlloc(&dptr, bytesize));
|
||||
Py_END_ALLOW_THREADS;
|
||||
|
||||
return PyLong_FromUnsignedLongLong((unsigned long long)dptr);
|
||||
@@ -307,7 +326,8 @@ static PyObject *memcpyHtoD(PyObject *self, PyObject *args) {
|
||||
srcHost = (const void *)srcHostPtr;
|
||||
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
CUDA_CHECK(cuMemcpyHtoD(dstDevice, srcHost, byteCount));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
||||
cuMemcpyHtoD(dstDevice, srcHost, byteCount));
|
||||
Py_END_ALLOW_THREADS;
|
||||
|
||||
Py_RETURN_NONE;
|
||||
@@ -321,7 +341,7 @@ static PyObject *memFree(PyObject *self, PyObject *args) {
|
||||
}
|
||||
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
CUDA_CHECK(cuMemFree(dptr));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuMemFree(dptr));
|
||||
Py_END_ALLOW_THREADS;
|
||||
|
||||
Py_RETURN_NONE;
|
||||
@@ -411,7 +431,7 @@ static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) {
|
||||
}
|
||||
// Call the function
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
CUDA_CHECK(cuTensorMapEncodeTiledHandle(
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuTensorMapEncodeTiledHandle(
|
||||
tensorMap, tensorDataType, tensorRank, globalAddress, globalDim,
|
||||
globalStrides, boxDim, elementStrides, interleave, swizzle, l2Promotion,
|
||||
oobFill));
|
||||
|
||||
@@ -19,6 +19,7 @@ def default_dump_dir():
|
||||
|
||||
|
||||
class CacheManager(ABC):
|
||||
|
||||
def __init__(self, key):
|
||||
pass
|
||||
|
||||
@@ -44,20 +45,21 @@ class CacheManager(ABC):
|
||||
|
||||
|
||||
class FileCacheManager(CacheManager):
|
||||
|
||||
def __init__(self, key, override=False, dump=False):
|
||||
self.key = key
|
||||
self.lock_path = None
|
||||
if (dump):
|
||||
if dump:
|
||||
self.cache_dir = default_dump_dir()
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||
os.makedirs(self.cache_dir, exist_ok=True)
|
||||
elif (override):
|
||||
elif override:
|
||||
self.cache_dir = default_override_dir()
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
else:
|
||||
# create cache directory if it doesn't exist
|
||||
self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir()
|
||||
self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
|
||||
if self.cache_dir:
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||
@@ -93,9 +95,8 @@ class FileCacheManager(CacheManager):
|
||||
result = {}
|
||||
for c in child_paths:
|
||||
p = self._make_path(c)
|
||||
if not os.path.exists(p):
|
||||
raise Exception(f"Group file {p} does not exist from group {grp_filename} ")
|
||||
result[c] = p
|
||||
if os.path.exists(p):
|
||||
result[c] = p
|
||||
return result
|
||||
|
||||
# Note a group of pushed files as being part of a group
|
||||
@@ -142,6 +143,7 @@ def get_cache_manager(key) -> CacheManager:
|
||||
|
||||
if user_cache_manager is not None and user_cache_manager != __cache_cls_nme:
|
||||
import importlib
|
||||
|
||||
module_path, clz_nme = user_cache_manager.split(":")
|
||||
module = importlib.import_module(module_path)
|
||||
__cache_cls = getattr(module, clz_nme)
|
||||
|
||||
@@ -9,7 +9,6 @@ from .cache import get_cache_manager
|
||||
|
||||
|
||||
class DriverBase(metaclass=abc.ABCMeta):
|
||||
|
||||
CUDA = 0
|
||||
HIP = 1
|
||||
|
||||
@@ -19,6 +18,8 @@ class DriverBase(metaclass=abc.ABCMeta):
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# CUDA
|
||||
# -----------------------------
|
||||
@@ -27,7 +28,7 @@ class DriverBase(metaclass=abc.ABCMeta):
|
||||
class CudaUtils(object):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(CudaUtils, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
@@ -47,6 +48,7 @@ class CudaUtils(object):
|
||||
with open(so, "rb") as f:
|
||||
cache_path = cache.put(f.read(), fname, binary=True)
|
||||
import importlib.util
|
||||
|
||||
spec = importlib.util.spec_from_file_location("cuda_utils", cache_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
@@ -66,7 +68,7 @@ class CudaUtils(object):
|
||||
class CudaDriver(DriverBase):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(CudaDriver, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
@@ -74,14 +76,16 @@ class CudaDriver(DriverBase):
|
||||
self.utils = CudaUtils()
|
||||
self.backend = self.CUDA
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# HIP
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class HIPUtils(object):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(HIPUtils, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
@@ -101,6 +105,7 @@ class HIPUtils(object):
|
||||
with open(so, "rb") as f:
|
||||
cache_path = cache.put(f.read(), fname, binary=True)
|
||||
import importlib.util
|
||||
|
||||
spec = importlib.util.spec_from_file_location("hip_utils", cache_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
@@ -111,7 +116,7 @@ class HIPUtils(object):
|
||||
class HIPDriver(DriverBase):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(HIPDriver, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
@@ -123,7 +128,7 @@ class HIPDriver(DriverBase):
|
||||
class UnsupportedDriver(DriverBase):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(UnsupportedDriver, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
@@ -131,12 +136,14 @@ class UnsupportedDriver(DriverBase):
|
||||
self.utils = None
|
||||
self.backend = None
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Driver
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class LazyProxy:
|
||||
|
||||
def __init__(self, init_fn):
|
||||
self._init_fn = init_fn
|
||||
self._obj = None
|
||||
@@ -150,7 +157,7 @@ class LazyProxy:
|
||||
return getattr(self._obj, name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name in ['_init_fn', '_obj']:
|
||||
if name in ["_init_fn", "_obj"]:
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
self._initialize_obj()
|
||||
@@ -172,6 +179,7 @@ class LazyProxy:
|
||||
|
||||
def initialize_driver():
|
||||
import torch
|
||||
|
||||
if torch.version.hip is not None:
|
||||
return HIPDriver()
|
||||
elif torch.cuda.is_available():
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
|
||||
class OutOfResources(Exception):
|
||||
|
||||
def __init__(self, required, limit, name):
|
||||
self.message = f'out of resource: {name}, '\
|
||||
f'Required: {required}, '\
|
||||
f'Hardware limit: {limit}'
|
||||
self.message += '. Reducing block sizes or `num_stages` may help.'
|
||||
self.message = f"out of resource: {name}, " f"Required: {required}, " f"Hardware limit: {limit}"
|
||||
self.message += ". Reducing block sizes or `num_stages` may help."
|
||||
self.required = required
|
||||
self.limit = limit
|
||||
self.name = name
|
||||
|
||||
@@ -74,11 +74,15 @@ class BlockPointerHandle:
|
||||
|
||||
|
||||
def wrap_ret(compute_ret_ty):
    """Build a decorator that re-wraps a function's result in a ``TensorHandle``.

    The decorated function is called first; the ``.data`` of its result is
    then paired with the type returned by ``compute_ret_ty``, which receives
    exactly the same positional and keyword arguments.
    """

    def wrapper(fn):

        def wrapped(*args, **kwargs):
            raw = fn(*args, **kwargs)
            payload = raw.data
            ret_ty = compute_ret_ty(*args, **kwargs)
            return TensorHandle(payload, ret_ty)

        return wrapped

    return wrapper
|
||||
|
||||
|
||||
@@ -249,11 +253,13 @@ class Builder:
|
||||
# ternary functions
|
||||
def ternary_op(self, lhs, rhs, other, op):
|
||||
return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype)
|
||||
|
||||
create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where)
|
||||
|
||||
# unary functions
|
||||
def unary_op(self, arg, op):
|
||||
return TensorHandle(op(arg.data), arg.dtype)
|
||||
|
||||
create_exp = lambda self, arg: self.unary_op(arg, np.exp)
|
||||
create_cos = lambda self, arg: self.unary_op(arg, np.cos)
|
||||
create_sin = lambda self, arg: self.unary_op(arg, np.sin)
|
||||
@@ -279,7 +285,8 @@ class Builder:
|
||||
dtype_tt = ptr.dtype.element_ty
|
||||
return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype)
|
||||
|
||||
def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile):
|
||||
def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy,
|
||||
is_volatile):
|
||||
ptrs, masks = ptr.materialize_pointers(boundary_check)
|
||||
assert padding_option is None
|
||||
other = None
|
||||
@@ -297,6 +304,7 @@ class Builder:
|
||||
|
||||
def create_int_to_ptr(self, val, dst_ty):
|
||||
return TensorHandle(val.data.astype(np.uint64), dst_ty)
|
||||
|
||||
# def create_cat(self, lhs, rhs):
|
||||
# pass
|
||||
|
||||
@@ -360,7 +368,10 @@ class Builder:
|
||||
|
||||
|
||||
def patch_attr(obj, name, member, builder):
    """Replace ``obj.<name>`` with a wrapper that always injects ``builder``.

    The wrapper forwards every positional and keyword argument to ``member``,
    except that any caller-supplied ``_builder`` keyword is discarded and
    replaced by ``_builder=builder``.
    """

    # ``member`` stays a keyword-only default so the wrapper's signature is
    # the same as the lambda it replaces.
    def new_member(*args, member=member, **kwargs):
        forwarded = {key: value for key, value in kwargs.items() if key != "_builder"}
        return member(*args, **forwarded, _builder=builder)

    setattr(obj, name, new_member)
|
||||
|
||||
|
||||
@@ -384,8 +395,8 @@ def _patch_lang_core(lang, builder):
|
||||
def _new_reduce(input, axis, combine_fn):
|
||||
fn = combine_fn.fn.__name__
|
||||
mapping = {
|
||||
'maximum': np.max,
|
||||
'_sum_combine': np.sum,
|
||||
"maximum": np.max,
|
||||
"_sum_combine": np.sum,
|
||||
}
|
||||
ret = mapping[fn](input.handle.data, axis=axis)
|
||||
ret_type = tl.block_type(input.dtype, ret.shape)
|
||||
@@ -397,15 +408,16 @@ def _patch_lang_core(lang, builder):
|
||||
def _patch_lang_math(lang, builder):
|
||||
math = lang.math
|
||||
mapping = {
|
||||
'abs': 'abs',
|
||||
'acos': 'arccos',
|
||||
'asin': 'arcsin',
|
||||
'exp2': 'exp2',
|
||||
'log2': 'log2',
|
||||
'max': 'maximum',
|
||||
"abs": "abs",
|
||||
"acos": "arccos",
|
||||
"asin": "arcsin",
|
||||
"exp2": "exp2",
|
||||
"log2": "log2",
|
||||
"max": "maximum",
|
||||
}
|
||||
|
||||
def make_numpy(name):
|
||||
|
||||
def impl(*args, **kwargs):
|
||||
ret_type = args[0].type # TODO: incorrect
|
||||
ret_dtype = args[0].dtype # TODO: incorrect
|
||||
@@ -414,15 +426,18 @@ def _patch_lang_math(lang, builder):
|
||||
ret = getattr(np, mapping[name])(*args, **kwargs)
|
||||
ret = tl.core.tensor(TensorHandle(ret, ret_dtype), ret_type)
|
||||
return ret
|
||||
|
||||
return impl
|
||||
|
||||
def make_fallback(name):
|
||||
|
||||
def fallback(*args, **kwargs):
|
||||
raise NotImplementedError(f"""
|
||||
{name} not supported in interpreter mode: no known numpy implementation.
|
||||
If you think that {name} in fact does have a numpy implementation, please add it
|
||||
to the mapping in python/triton/interpreter/new_interpreter.py:_patch_lang_math.
|
||||
""")
|
||||
|
||||
return fallback
|
||||
|
||||
for name, member in inspect.getmembers(math):
|
||||
@@ -438,7 +453,7 @@ def _implicit_cvt(arg):
|
||||
ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
|
||||
handle = TensorHandle(np.array([arg], dtype=np.int32), ty)
|
||||
return tl.tensor(handle, ty)
|
||||
if hasattr(arg, 'data_ptr'):
|
||||
if hasattr(arg, "data_ptr"):
|
||||
ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
|
||||
handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
|
||||
return tl.tensor(handle, ty)
|
||||
@@ -453,28 +468,29 @@ def _unwrap(tensor):
|
||||
|
||||
builder = Builder()
|
||||
|
||||
RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization', 'enable_fp_fusion']
|
||||
RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_warp_specialization", "enable_fp_fusion"]
|
||||
|
||||
|
||||
class GridExecutor:
|
||||
|
||||
def __init__(self, fn, arg_names, grid):
|
||||
from .jit import _normalize_ty # TODO: modularize
|
||||
|
||||
self.fn = fn
|
||||
self.arg_names = arg_names
|
||||
self.grid = grid
|
||||
__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
|
||||
self.constexprs = [name for name in arg_names if __annotations__.get(name) == 'constexpr']
|
||||
self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"]
|
||||
|
||||
def _patch_lang(self, builder):
|
||||
lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
|
||||
assert len(lang) == 1, "triton.language must be visible from within jit'd function"
|
||||
_patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
|
||||
_patch_lang_tensor(getattr(lang[0], "tensor"), builder)
|
||||
_patch_lang_core(lang[0], builder)
|
||||
_patch_lang_math(lang[0], builder)
|
||||
|
||||
def __call__(self, *args_dev, **kwargs):
|
||||
args_hst = [_unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev]
|
||||
args_hst = [_unwrap(arg).cpu() if hasattr(arg, "data_ptr") else arg for arg in args_dev]
|
||||
# removes reserved keywords from kwargs
|
||||
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}
|
||||
# remaps core language functions to interpreted ones
|
||||
@@ -486,7 +502,7 @@ class GridExecutor:
|
||||
# iterate through grid
|
||||
grid = self.grid(args) if callable(self.grid) else self.grid
|
||||
assert len(grid) <= 3
|
||||
grid = grid + (1,) * (3 - len(grid))
|
||||
grid = grid + (1, ) * (3 - len(grid))
|
||||
builder.set_grid_dim(*grid)
|
||||
for x in range(grid[0]):
|
||||
for y in range(grid[1]):
|
||||
@@ -495,7 +511,7 @@ class GridExecutor:
|
||||
self.fn(**args)
|
||||
# copy arguments back to propagate side-effects
|
||||
for arg_dev, arg_hst in zip(args_dev, args_hst):
|
||||
if hasattr(arg_dev, 'data_ptr'):
|
||||
if hasattr(arg_dev, "data_ptr"):
|
||||
_unwrap(arg_dev).copy_(arg_hst.to(arg_dev.device))
|
||||
|
||||
|
||||
@@ -504,17 +520,18 @@ class InterpretedFunction:
|
||||
def _patch_lang(self, builder):
|
||||
lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
|
||||
assert len(lang) == 1, "triton.language must be visible from within jit'd function"
|
||||
_patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
|
||||
_patch_lang_tensor(getattr(lang[0], "tensor"), builder)
|
||||
_patch_lang_core(lang[0], builder)
|
||||
|
||||
def __init__(self, fn) -> None:
|
||||
self.fn = fn
|
||||
|
||||
def run(*args, **kwargs):
|
||||
grid = kwargs['grid']
|
||||
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ['grid']}
|
||||
grid = kwargs["grid"]
|
||||
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ["grid"]}
|
||||
|
||||
return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs)
|
||||
|
||||
self.run = run
|
||||
signature = inspect.signature(fn)
|
||||
self.arg_names = [v.name for v in signature.parameters.values()]
|
||||
|
||||
@@ -5,48 +5,48 @@ import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import subprocess
|
||||
import textwrap
|
||||
from collections import defaultdict, namedtuple
|
||||
from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast,
|
||||
overload)
|
||||
from functools import cached_property
|
||||
from typing import Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast, overload
|
||||
|
||||
from .._C.libtriton.triton import TMAInfos
|
||||
from ..common.backend import get_backend, path_to_ptxas
|
||||
from ..language.core import dtype
|
||||
from ..common.backend import get_backend, get_cuda_version_key
|
||||
from .interpreter import InterpretedFunction
|
||||
|
||||
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
TRITON_VERSION = "2.1.0"
|
||||
|
||||
|
||||
def get_cuda_stream(idx=None):
    """Return PyTorch's current CUDA stream handle for device ``idx``.

    When ``idx`` is ``None`` the device reported by ``get_current_device()``
    is used. The private ``torch._C`` accessor is tried first; if that raises
    ``ImportError`` the public ``torch.cuda.current_stream`` API is used.
    """
    device = get_current_device() if idx is None else idx
    try:
        from torch._C import _cuda_getCurrentRawStream as raw_stream_of

        return raw_stream_of(device)
    except ImportError:
        import torch

        return torch.cuda.current_stream(device).cuda_stream
|
||||
|
||||
|
||||
def get_current_device():
    """Return the value of ``torch.cuda.current_device()``."""
    # torch is imported lazily, inside the function, as in the other helpers here.
    from torch import cuda

    return cuda.current_device()
|
||||
|
||||
|
||||
def set_current_device(idx):
    """Call ``torch.cuda.set_device(idx)``; returns ``None``."""
    # torch is imported lazily, inside the function, as in the other helpers here.
    from torch import cuda

    cuda.set_device(idx)
|
||||
|
||||
|
||||
def get_device_capability(idx):
    """Return the value of ``torch.cuda.get_device_capability(idx)``."""
    # torch is imported lazily, inside the function, as in the other helpers here.
    from torch import cuda

    return cuda.get_device_capability(idx)
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Dependencies Finder
|
||||
@@ -72,7 +72,8 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
lhs = self.visit(node.value)
|
||||
while isinstance(lhs, ast.Attribute):
|
||||
lhs = self.visit(lhs.value)
|
||||
if lhs is None or (getattr(lhs, "__name__", "") == "triton" or getattr(lhs, "__name__", "").endswith(".triton")):
|
||||
if lhs is None or (getattr(lhs, "__name__", "") == "triton"
|
||||
or getattr(lhs, "__name__", "").endswith(".triton")):
|
||||
return None
|
||||
return getattr(lhs, node.attr)
|
||||
|
||||
@@ -82,55 +83,26 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
return
|
||||
if inspect.isbuiltin(func):
|
||||
return
|
||||
if func.__module__ and (func.__module__.startswith('triton.') or '.triton.' in func.__module__):
|
||||
if func.__module__ and (func.__module__.startswith("triton.") or ".triton." in func.__module__):
|
||||
return
|
||||
assert isinstance(func, JITFunction), f"Function \"{func.__name__}\" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this"
|
||||
assert isinstance(
|
||||
func, JITFunction
|
||||
), f'Function "{func.__name__}" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this'
|
||||
if func.hash is None:
|
||||
tree = ast.parse(func.src)
|
||||
finder = DependenciesFinder(func.__globals__, func.src)
|
||||
finder.visit(tree)
|
||||
func.hash = finder.ret
|
||||
noinline = str(getattr(func, 'noinline', False))
|
||||
noinline = str(getattr(func, "noinline", False))
|
||||
self.ret = (self.ret + func.hash + noinline).encode("utf-8")
|
||||
self.ret = hashlib.sha1(self.ret).hexdigest()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# JITFunction
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def version_key():
|
||||
import pkgutil
|
||||
contents = []
|
||||
# frontend
|
||||
with open(__file__, "rb") as f:
|
||||
contents += [hashlib.sha1(f.read()).hexdigest()]
|
||||
# compiler
|
||||
compiler_path = os.path.join(TRITON_PATH, 'compiler')
|
||||
for lib in pkgutil.iter_modules([compiler_path]):
|
||||
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
||||
contents += [hashlib.sha1(f.read()).hexdigest()]
|
||||
# backend
|
||||
libtriton_hash = hashlib.sha1()
|
||||
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
|
||||
while True:
|
||||
chunk = f.read(1024 ** 2)
|
||||
if not chunk:
|
||||
break
|
||||
libtriton_hash.update(chunk)
|
||||
contents.append(libtriton_hash.hexdigest())
|
||||
# language
|
||||
language_path = os.path.join(TRITON_PATH, 'language')
|
||||
for lib in pkgutil.iter_modules([language_path]):
|
||||
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
||||
contents += [hashlib.sha1(f.read()).hexdigest()]
|
||||
# ptxas version
|
||||
ptxas = path_to_ptxas()[0]
|
||||
ptxas_version = hashlib.sha1(subprocess.check_output([ptxas, "--version"])).hexdigest()
|
||||
return '-'.join(TRITON_VERSION) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||
|
||||
|
||||
def _normalize_ty(ty) -> str:
|
||||
if isinstance(ty, type):
|
||||
return ty.__name__
|
||||
@@ -139,6 +111,85 @@ def _normalize_ty(ty) -> str:
|
||||
return repr(ty)
|
||||
|
||||
|
||||
class KernelParam:
    """Represents a parameter to a @jit'ed function.

    A parameter is just the name plus metadata; a parameter plus a value is a
    KernelArg.
    """

    def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool):
        # `num` is stored as given; `param` is the inspect.Parameter the
        # properties below read from.
        self.num = num
        self._param = param
        self.do_not_specialize = do_not_specialize

    @cached_property
    def name(self):
        """Name of the underlying ``inspect.Parameter``."""
        return self._param.name

    @cached_property
    def annotation(self):
        """Normalized annotation string, or ``""`` when there is none."""
        raw = self._param.annotation
        # Compare against the sentinel by identity so an annotation object
        # with a custom ``__eq__`` cannot be mistaken for "no annotation".
        if raw is inspect.Parameter.empty or not raw:
            return ""
        return _normalize_ty(raw)

    @cached_property
    def is_constexpr(self):
        """True when the normalized annotation contains ``"constexpr"``."""
        return "constexpr" in self.annotation

    @property
    def default(self):
        """Declared default, or ``inspect.Parameter.empty`` when absent."""
        return self._param.default

    @property
    def has_default(self):
        """True when the parameter declares a default value."""
        # Identity test: ``!=`` would call the default's own ``__ne__``, which
        # for tensor-like or custom objects may not return a plain bool.
        return self._param.default is not inspect.Parameter.empty
|
||||
|
||||
|
||||
class KernelArg:
    """A concrete argument passed to a @jit'ed function.

    Pairs a runtime value with the parameter object describing it.
    """

    def __init__(self, value, param):
        self.value = value
        self.param = param

    @property
    def name(self):
        """Name of the associated parameter."""
        return self.param.name

    def signature_key(self):
        """Return the type key for this argument.

        Tensor-annotated arguments yield ``value.dtype``; ``bool`` and
        ``float`` annotations map to ``"i1"`` and ``"fp32"``; everything else
        is delegated to ``JITFunction._key_of``.
        """
        ann = self.param.annotation
        if "Tensor" in ann:
            return self.value.dtype
        if ann == "bool":
            return "i1"
        if ann == "float":
            return "fp32"
        return JITFunction._key_of(self.value)

    def specialization_key(self):
        """Return a tuple of booleans describing the value.

        Values exposing ``data_ptr()`` report whether it is a multiple of
        ``JITFunction.divisibility``. Integers report divisibility by
        ``JITFunction.divisibility`` and ``JITFunction.divisibility_8``, plus
        equality to 1. Anything else yields ``(False, )``.
        """
        assert not self.param.do_not_specialize

        val = self.value
        try:
            return (val.data_ptr() % JITFunction.divisibility == 0, )
        except AttributeError:
            # No usable data_ptr(): fall through to the integer check.
            pass

        # bool is a subclass of int, so booleans take the integer path too.
        if not isinstance(val, int):
            return (False, )
        return (
            val % JITFunction.divisibility == 0,
            val % JITFunction.divisibility_8 == 0,
            val == 1,
        )
|
||||
|
||||
|
||||
class KernelInterface(Generic[T]):
|
||||
run: T
|
||||
|
||||
@@ -152,7 +203,6 @@ class KernelInterface(Generic[T]):
|
||||
|
||||
|
||||
class JITFunction(KernelInterface[T]):
|
||||
|
||||
# Hook for inspecting compiled functions and modules
|
||||
cache_hook = None
|
||||
divisibility = 16
|
||||
@@ -169,44 +219,44 @@ class JITFunction(KernelInterface[T]):
|
||||
elif isinstance(arg, bool):
|
||||
return "i1"
|
||||
elif isinstance(arg, int):
|
||||
if -2**31 <= arg and arg <= 2**31 - 1:
|
||||
if -(2**31) <= arg and arg <= 2**31 - 1:
|
||||
return "i32"
|
||||
elif 2**63 <= arg and arg <= 2**64 - 1:
|
||||
return "u64"
|
||||
else:
|
||||
return "i64"
|
||||
elif isinstance(arg, float):
|
||||
return 'fp32'
|
||||
return "fp32"
|
||||
elif arg is None:
|
||||
return None
|
||||
else:
|
||||
raise TypeError(f'Unsupported type {type(arg)} for {arg}')
|
||||
raise TypeError(f"Unsupported type {type(arg)} for {arg}")
|
||||
|
||||
@staticmethod
|
||||
def _device_of(arg):
|
||||
if hasattr(arg, "device"):
|
||||
if hasattr(arg.device, 'type'):
|
||||
return arg.device.type
|
||||
|
||||
return ''
|
||||
try:
|
||||
return arg.device.type
|
||||
except AttributeError:
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _pinned_memory_of(arg):
|
||||
if hasattr(arg, "is_pinned"):
|
||||
if isinstance(arg.is_pinned, Callable):
|
||||
return arg.is_pinned()
|
||||
|
||||
return False
|
||||
try:
|
||||
return arg.is_pinned()
|
||||
except (AttributeError, TypeError):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _spec_of(arg):
|
||||
if hasattr(arg, "data_ptr"):
|
||||
return (arg.data_ptr() % JITFunction.divisibility == 0)
|
||||
return arg.data_ptr() % JITFunction.divisibility == 0
|
||||
elif isinstance(arg, int):
|
||||
return (arg % 16 == 0, arg == 1)
|
||||
return (arg is None, )
|
||||
|
||||
# TODO(jlebar): Fold this into the KernelArg class.
|
||||
def _get_config(self, *args):
|
||||
|
||||
def is_divisible_by_16(x):
|
||||
if hasattr(x, "data_ptr"):
|
||||
return x.data_ptr() % JITFunction.divisibility == 0
|
||||
@@ -222,28 +272,38 @@ class JITFunction(KernelInterface[T]):
|
||||
if x is None:
|
||||
return True
|
||||
return False
|
||||
divisible_by_16 = {i for i, arg in enumerate(
|
||||
args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
|
||||
divisible_by_8 = {i for i, arg in enumerate(
|
||||
args) if is_divisible_by_8(arg) and i not in self.do_not_specialize}
|
||||
|
||||
divisible_by_16 = {
|
||||
param.num
|
||||
for param, arg in zip(self.params, args)
|
||||
if is_divisible_by_16(arg) and not param.do_not_specialize
|
||||
}
|
||||
divisible_by_8 = {
|
||||
param.num
|
||||
for param, arg in zip(self.params, args)
|
||||
if is_divisible_by_8(arg) and not param.do_not_specialize
|
||||
}
|
||||
equal_to_1 = {
|
||||
i for i, arg in enumerate(args) if isinstance(
|
||||
arg, int) and not isinstance(
|
||||
arg, bool) and arg == 1 and i not in self.do_not_specialize}
|
||||
param.num
|
||||
for param, arg in zip(self.params, args)
|
||||
if isinstance(arg, int) and not isinstance(arg, bool) and arg == 1 and not param.do_not_specialize
|
||||
}
|
||||
# folded equal_to_1 and None
|
||||
# TODO: method to collect all folded args
|
||||
none_args = {i for i, arg in enumerate(args) if arg is None and i not in self.do_not_specialize}
|
||||
none_args = {param.num for param, arg in zip(self.params, args) if arg is None and not param.do_not_specialize}
|
||||
ids_of_folded_args = equal_to_1 | none_args
|
||||
return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])(
|
||||
tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), tuple(divisible_by_8))
|
||||
return namedtuple("instance_descriptor",
|
||||
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])( #
|
||||
tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args),
|
||||
tuple(divisible_by_8))
|
||||
# return _triton.code_gen.instance_descriptor(divisible_by_16,
|
||||
# equal_to_1)
|
||||
|
||||
@staticmethod
|
||||
def _type_of(key):
|
||||
# None are nullptr -- implicitly converted to *i8
|
||||
# `None` is nullptr. Implicitly convert to *i8.
|
||||
if key is None:
|
||||
return '*i8'
|
||||
return "*i8"
|
||||
dtype_str = str(key).split(".")[-1]
|
||||
tys = {
|
||||
"bool": "i1",
|
||||
@@ -281,187 +341,265 @@ class JITFunction(KernelInterface[T]):
|
||||
constants = dict(zip(self.constexprs, constexpr_key))
|
||||
return constants
|
||||
|
||||
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization,enable_fp_fusion, extern_libs, configs):
|
||||
def _call_hook(
|
||||
self,
|
||||
key,
|
||||
signature,
|
||||
device,
|
||||
constants,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
num_stages,
|
||||
waves_per_eu,
|
||||
matrix_instr_nonkdim,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
extern_libs,
|
||||
configs,
|
||||
):
|
||||
if JITFunction.cache_hook is None:
|
||||
return False
|
||||
|
||||
name = self.fn.__name__
|
||||
module = self.fn.__module__
|
||||
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
|
||||
arg_reprs = ', '.join([f'{param.name}: {ty}' for param, ty in zip(self.params, key[1])])
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs}), enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
|
||||
key = str(key)
|
||||
|
||||
class LegacyCompiler:
|
||||
|
||||
def __init__(self, module, name):
|
||||
self.module = module
|
||||
self.name = name
|
||||
pass
|
||||
|
||||
kwargs = dict(signature=signature, device=device, constants=constants,
|
||||
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs,
|
||||
configs=configs)
|
||||
kwargs = dict(
|
||||
signature=signature,
|
||||
device=device,
|
||||
constants=constants,
|
||||
num_warps=num_warps,
|
||||
num_ctas=num_ctas,
|
||||
num_stages=num_stages,
|
||||
waves_per_eu=waves_per_eu,
|
||||
enable_warp_specialization=enable_warp_specialization,
|
||||
enable_fp_fusion=enable_fp_fusion,
|
||||
extern_libs=extern_libs,
|
||||
configs=configs)
|
||||
|
||||
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={
|
||||
"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
|
||||
|
||||
def _get_arg_specialization_key(self, arg_name, arg):
|
||||
arg_annotation = self.__annotations__.get(arg_name, '')
|
||||
if arg_annotation == '':
|
||||
return (arg.data_ptr() % JITFunction.divisibility == 0) if hasattr(arg, "data_ptr") \
|
||||
else (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1) if isinstance(arg, int) \
|
||||
else (False,)
|
||||
elif 'Tensor' in arg_annotation:
|
||||
return (arg.data_ptr() % JITFunction.divisibility == 0)
|
||||
elif 'int' in arg_annotation or 'bool' in arg_annotation:
|
||||
return (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1)
|
||||
else:
|
||||
return (False,)
|
||||
|
||||
def _get_arg_sig_key(self, arg_name, arg) -> str:
|
||||
arg_annotation = self.__annotations__.get(arg_name, '')
|
||||
if 'Tensor' in arg_annotation:
|
||||
return arg.dtype
|
||||
elif arg_annotation == 'bool':
|
||||
return "i1"
|
||||
elif arg_annotation == 'float':
|
||||
return 'fp32'
|
||||
else:
|
||||
return self._key_of(arg)
|
||||
return JITFunction.cache_hook(
|
||||
key=key,
|
||||
repr=repr,
|
||||
fn=LegacyCompiler(module, name),
|
||||
compile={"key": key, **kwargs},
|
||||
is_manual_warmup=False,
|
||||
already_compiled=False,
|
||||
)
|
||||
|
||||
def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: List[bool]) -> str:
|
||||
device_types = [device_type for device_type in device_types if device_type != '']
|
||||
device_types = [device_type for device_type in device_types if device_type != ""]
|
||||
# Return cuda if one of the input tensors is cuda
|
||||
if 'cuda' in device_types:
|
||||
if "cuda" in device_types:
|
||||
import torch
|
||||
return 'hip' if torch.version.hip else 'cuda'
|
||||
|
||||
is_cpu = all(device_type == 'cpu' for device_type in device_types)
|
||||
return "hip" if torch.version.hip else "cuda"
|
||||
|
||||
is_cpu = all(device_type == "cpu" for device_type in device_types)
|
||||
is_pinned_memory = any(pinned_memory_flag for pinned_memory_flag in pinned_memory_flags)
|
||||
# Return cuda if all the input tensors are cpu while the memory is pinned
|
||||
if is_cpu and is_pinned_memory:
|
||||
return 'cuda'
|
||||
return "cuda"
|
||||
|
||||
return device_types[0] if len(device_types) > 0 else 'cuda'
|
||||
return device_types[0] if len(device_types) > 0 else "cuda"
|
||||
|
||||
def _make_launcher(self):
|
||||
regular_args = [arg for i, arg in enumerate(
|
||||
self.arg_names) if i not in self.constexprs]
|
||||
constexpr_args = [arg for i, arg in enumerate(
|
||||
self.arg_names) if i in self.constexprs]
|
||||
def run(self, *args, **kwargs):
|
||||
from ..compiler import CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps
|
||||
|
||||
def regular_args_v(args_proxy):
|
||||
return [args_proxy[arg_name] for arg_name in regular_args]
|
||||
# Get a compiler-flags arg like `num_warps` and remove it from kwargs.
|
||||
def get_special_arg(name: str, default=None):
|
||||
if name not in kwargs:
|
||||
return default
|
||||
ret = kwargs[name]
|
||||
del kwargs[name]
|
||||
return ret
|
||||
|
||||
def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type):
|
||||
from ..compiler import (CompiledKernel, compile,
|
||||
get_arch_default_num_stages,
|
||||
get_arch_default_num_warps)
|
||||
sig_key = tuple([self._get_arg_sig_key(arg_name, args_proxy[arg_name]) for arg_name in regular_args])
|
||||
constexpr_key = tuple([args_proxy[arg_name] for arg_name in constexpr_args])
|
||||
specializations = []
|
||||
for i, arg_name in enumerate(regular_args):
|
||||
if i in self.do_not_specialize:
|
||||
continue
|
||||
specializations += [self._get_arg_specialization_key(arg_name, args_proxy[arg_name])]
|
||||
grid = get_special_arg("grid")
|
||||
num_warps = get_special_arg("num_warps")
|
||||
num_ctas = get_special_arg("num_ctas", 1)
|
||||
num_stages = get_special_arg("num_stages")
|
||||
waves_per_eu = get_special_arg("waves_per_eu", 0)
|
||||
matrix_instr_nonkdim = get_special_arg("matrix_instr_nonkdim", 0)
|
||||
enable_warp_specialization = get_special_arg("enable_warp_specialization", False)
|
||||
enable_fp_fusion = get_special_arg("enable_fp_fusion", True)
|
||||
extern_libs = get_special_arg("extern_libs")
|
||||
stream = get_special_arg("stream")
|
||||
warmup = get_special_arg("warmup", False)
|
||||
device = get_special_arg("device")
|
||||
device_type = get_special_arg("device_type")
|
||||
|
||||
spec_key = tuple(specializations)
|
||||
assert num_ctas > 0
|
||||
assert grid is not None
|
||||
if callable(grid):
|
||||
grid = grid(args_proxy)
|
||||
grid_size = len(grid)
|
||||
grid_0 = grid[0]
|
||||
grid_1 = grid[1] if grid_size > 1 else 1
|
||||
grid_2 = grid[2] if grid_size > 2 else 1
|
||||
if device_type is None:
|
||||
device_types = [self._device_of(arg) for arg in regular_args_v(args_proxy)]
|
||||
device_types = [_device_type for _device_type in device_types if _device_type != '']
|
||||
device_type = self._conclude_device_type(device_types, [self._pinned_memory_of(arg) for arg in
|
||||
regular_args_v(args_proxy)])
|
||||
# Bind the remaining arguments to `fn`.
|
||||
bound_args = self.signature.bind(*args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
|
||||
device_backend = None
|
||||
if device_type not in ['cuda']:
|
||||
device_backend = get_backend(device_type)
|
||||
if device_backend is None:
|
||||
raise ValueError('Cannot find backend for ' + device_type)
|
||||
assert len(bound_args.arguments) == len(self.params)
|
||||
args = [KernelArg(arg_value, param) for (_, arg_value), param in zip(bound_args.arguments.items(), self.params)]
|
||||
|
||||
if device is None:
|
||||
if device_type in ['cuda']:
|
||||
device = get_current_device()
|
||||
set_current_device(device)
|
||||
else:
|
||||
device = device_backend.get_current_device()
|
||||
device_backend.set_current_device(device)
|
||||
if stream is None and not warmup:
|
||||
if device_type in ['cuda']:
|
||||
stream = get_cuda_stream(device)
|
||||
else:
|
||||
stream = device_backend.get_stream()
|
||||
non_constexpr_arg_values = [arg.value for arg in args if not arg.param.is_constexpr]
|
||||
|
||||
if num_warps is None:
|
||||
num_warps = get_arch_default_num_warps(device_type)
|
||||
if num_stages is None:
|
||||
num_stages = get_arch_default_num_stages(device_type)
|
||||
sig_key = tuple(arg.signature_key() for arg in args if not arg.param.is_constexpr)
|
||||
spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize)
|
||||
constexpr_key = tuple(arg.value for arg in args if arg.param.is_constexpr)
|
||||
|
||||
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, enable_fp_fusion, self.debug)
|
||||
if extern_libs is not None:
|
||||
key = (key, tuple(extern_libs.items()))
|
||||
assert num_ctas > 0
|
||||
assert grid is not None
|
||||
if callable(grid):
|
||||
# Arguments are passed as a dict to `grid`, by contract.
|
||||
# TODO(jlebar): In the new launch API, pass the compiler flags as a
|
||||
# second parameter to `grid`.
|
||||
grid = grid(dict(bound_args.arguments))
|
||||
grid_size = len(grid)
|
||||
grid_0 = grid[0]
|
||||
grid_1 = grid[1] if grid_size > 1 else 1
|
||||
grid_2 = grid[2] if grid_size > 2 else 1
|
||||
if device_type is None:
|
||||
device_types = [self._device_of(arg) for arg in non_constexpr_arg_values]
|
||||
device_types = [_device_type for _device_type in device_types if _device_type != ""]
|
||||
device_type = self._conclude_device_type(device_types,
|
||||
[self._pinned_memory_of(arg) for arg in non_constexpr_arg_values])
|
||||
|
||||
bin = self.cache[device].get(key, None)
|
||||
if bin is not None:
|
||||
# build dict of constant values
|
||||
args = regular_args_v(args_proxy)
|
||||
# Create tensormaps and append to args
|
||||
args = bin.assemble_tensormap_to_arg(args)
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
|
||||
return bin
|
||||
# kernel not cached -- compile
|
||||
device_backend = None
|
||||
if device_type not in ["cuda"]:
|
||||
device_backend = get_backend(device_type)
|
||||
if device_backend is None:
|
||||
raise ValueError("Cannot find backend for " + device_type)
|
||||
|
||||
if device is None:
|
||||
if device_type in ["cuda"]:
|
||||
device = get_current_device()
|
||||
set_current_device(device)
|
||||
else:
|
||||
# build dict of constant values
|
||||
args = regular_args_v(args_proxy)
|
||||
all_args = tuple([args_proxy[arg_name] for arg_name in self.arg_names])
|
||||
configs = self._get_config(*all_args),
|
||||
constants = self._make_constants(constexpr_key)
|
||||
constants.update({i: None for i, arg in enumerate(all_args) if arg is None})
|
||||
constants.update({i: 1 for i in configs[0].equal_to_1})
|
||||
# build kernel signature -- doesn't include specialized arguments
|
||||
signature = {i: self._type_of(self._key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs}
|
||||
# build stub signature -- includes arguments that are specialized
|
||||
for i, arg in constants.items():
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {i} is not supported")
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, enable_fp_fusion, extern_libs, configs):
|
||||
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=matrix_instr_nonkdim, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
# Create tensormaps and append to args
|
||||
args = bin.assemble_tensormap_to_arg(args)
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
|
||||
self.cache[device][key] = bin
|
||||
return bin
|
||||
device = device_backend.get_current_device()
|
||||
device_backend.set_current_device(device)
|
||||
if stream is None and not warmup:
|
||||
if device_type in ["cuda"]:
|
||||
stream = get_cuda_stream(device)
|
||||
else:
|
||||
stream = device_backend.get_stream()
|
||||
|
||||
if num_warps is None:
|
||||
num_warps = get_arch_default_num_warps(device_type)
|
||||
if num_stages is None:
|
||||
num_stages = get_arch_default_num_stages(device_type)
|
||||
|
||||
if device_type in ["cuda"]:
|
||||
version_key = get_cuda_version_key()
|
||||
else:
|
||||
version_key = device_backend.get_version_key()
|
||||
key = (
|
||||
version_key,
|
||||
sig_key,
|
||||
constexpr_key,
|
||||
spec_key,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
num_stages,
|
||||
waves_per_eu,
|
||||
matrix_instr_nonkdim,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
self.debug,
|
||||
)
|
||||
if extern_libs is not None:
|
||||
key = (key, tuple(extern_libs.items()))
|
||||
|
||||
# Kernel is not cached; we have to compile.
|
||||
if key not in self.cache[device]:
|
||||
configs = (self._get_config(*[arg.value for arg in args]), )
|
||||
constants = {
|
||||
arg.param.num: arg.value
|
||||
for arg in args
|
||||
if arg.param.is_constexpr or arg.param.num in configs[0].equal_to_1 or arg.value is None
|
||||
}
|
||||
for i, arg in constants.items():
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {i} is not supported")
|
||||
|
||||
# Build kernel signature -- doesn't include constexpr arguments.
|
||||
signature = {
|
||||
arg.param.num: self._type_of(self._key_of(arg.value))
|
||||
for arg in args
|
||||
if not arg.param.is_constexpr
|
||||
}
|
||||
|
||||
if self._call_hook(
|
||||
key,
|
||||
signature,
|
||||
device,
|
||||
constants,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
num_stages,
|
||||
waves_per_eu,
|
||||
matrix_instr_nonkdim,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
extern_libs,
|
||||
configs,
|
||||
):
|
||||
return None
|
||||
|
||||
# create a wrapper to call launcher_body
|
||||
args_map = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
|
||||
args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = triton.language.dtype(\'{dflt}\')' if dtype.is_dtype(f'{dflt}') else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
|
||||
args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
|
||||
src = f"""
|
||||
import triton
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, matrix_instr_nonkdim=0, enable_warp_specialization=False, enable_fp_fusion=True, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type)
|
||||
"""
|
||||
scope = {"launcher_body": launcher_body}
|
||||
exec(src, scope)
|
||||
return scope[self.fn.__name__]
|
||||
self.cache[device][key] = compile(
|
||||
self,
|
||||
signature=signature,
|
||||
device=device,
|
||||
constants=constants,
|
||||
num_warps=num_warps,
|
||||
num_ctas=num_ctas,
|
||||
num_stages=num_stages,
|
||||
waves_per_eu=waves_per_eu,
|
||||
matrix_instr_nonkdim=matrix_instr_nonkdim,
|
||||
enable_warp_specialization=enable_warp_specialization,
|
||||
enable_fp_fusion=enable_fp_fusion,
|
||||
extern_libs=extern_libs,
|
||||
configs=configs,
|
||||
debug=self.debug,
|
||||
device_type=device_type,
|
||||
)
|
||||
|
||||
bin = self.cache[device][key]
|
||||
if not warmup:
|
||||
bin.c_wrapper(
|
||||
grid_0,
|
||||
grid_1,
|
||||
grid_2,
|
||||
bin.num_warps,
|
||||
bin.num_ctas,
|
||||
bin.clusterDims[0],
|
||||
bin.clusterDims[1],
|
||||
bin.clusterDims[2],
|
||||
bin.shared,
|
||||
stream,
|
||||
bin.cu_function,
|
||||
CompiledKernel.launch_enter_hook,
|
||||
CompiledKernel.launch_exit_hook,
|
||||
bin,
|
||||
*bin.assemble_tensormap_to_arg(non_constexpr_arg_values),
|
||||
)
|
||||
return bin
|
||||
|
||||
def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None):
|
||||
do_not_specialize = do_not_specialize if do_not_specialize else []
|
||||
|
||||
self.fn = fn
|
||||
self.module = fn.__module__
|
||||
self.version = version
|
||||
# function signature information
|
||||
signature = inspect.signature(fn)
|
||||
self.arg_names = [v.name for v in signature.parameters.values()]
|
||||
self.arg_defaults = [v.default for v in signature.parameters.values()]
|
||||
self.has_defaults = any(v != inspect._empty for v in self.arg_defaults)
|
||||
self.signature = inspect.signature(fn)
|
||||
self.do_not_specialize = do_not_specialize
|
||||
|
||||
self.params = []
|
||||
for i, param in enumerate(self.signature.parameters.values()):
|
||||
dns = do_not_specialize and (i in do_not_specialize or param.name in do_not_specialize)
|
||||
self.params.append(KernelParam(i, param, dns))
|
||||
|
||||
# function source code (without decorators)
|
||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||
self.src = self.src[self.src.find("def"):]
|
||||
@@ -470,22 +608,18 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
self.hash = None
|
||||
# JITFunction can be instantiated as kernel
|
||||
# when called with a grid using __getitem__
|
||||
self.kernel_decorators = []
|
||||
self.kernel = None
|
||||
self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug
|
||||
self.noinline = noinline
|
||||
# annotations
|
||||
self.__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
|
||||
# index of constexprs
|
||||
self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
|
||||
# specialization hints
|
||||
regular_args = [arg for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
|
||||
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
|
||||
self.do_not_specialize = {regular_args.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
|
||||
|
||||
# tma info
|
||||
self.tensormaps_info = TMAInfos()
|
||||
# launcher
|
||||
self.run = self._make_launcher()
|
||||
|
||||
# TODO(jlebar): Remove uses of these fields outside this file, then
|
||||
# remove the fields here.
|
||||
self.arg_names = [p.name for p in self.params]
|
||||
self.constexprs = [p.num for p in self.params if p.is_constexpr]
|
||||
|
||||
# re-use docs of wrapped function
|
||||
self.__doc__ = fn.__doc__
|
||||
self.__name__ = fn.__name__
|
||||
@@ -498,7 +632,7 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
if self.hash is None:
|
||||
dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
|
||||
dependencies_finder.visit(self.parse())
|
||||
self.hash = dependencies_finder.ret + version_key()
|
||||
self.hash = dependencies_finder.ret
|
||||
return self.hash
|
||||
|
||||
def warmup(self, *args, **kwargs):
|
||||
@@ -518,14 +652,10 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
# - when kernel decorators change, cached kernel
|
||||
# needs to be cleared
|
||||
if name == 'kernel_decorators':
|
||||
self.kernel = None
|
||||
super(JITFunction, self).__setattr__(name, value)
|
||||
# - when `.src` attribute is set, cache path needs
|
||||
# to be reinitialized
|
||||
if name == 'src':
|
||||
if name == "src":
|
||||
self.hash = None
|
||||
|
||||
def __repr__(self):
|
||||
@@ -591,12 +721,14 @@ def jit(
|
||||
debug=debug,
|
||||
noinline=noinline,
|
||||
)
|
||||
|
||||
if fn is not None:
|
||||
return decorator(fn)
|
||||
|
||||
else:
|
||||
return decorator
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utilities for mocking tensors
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -607,10 +739,10 @@ class MockTensor:
|
||||
Can be used in place of real tensors when calling:
|
||||
kernel.warmup(MockTensor(torch.float32), ...)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def wrap_dtype(arg):
|
||||
if arg.__class__.__name__ == "dtype" and\
|
||||
arg.__module__ == "torch":
|
||||
if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch":
|
||||
return MockTensor(arg)
|
||||
return arg
|
||||
|
||||
@@ -623,6 +755,7 @@ class MockTensor:
|
||||
|
||||
|
||||
class TensorWrapper:
|
||||
|
||||
def __init__(self, base, dtype):
|
||||
self.dtype = dtype
|
||||
self.base = base
|
||||
@@ -637,7 +770,7 @@ class TensorWrapper:
|
||||
return self.base.stride(i)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'TensorWrapper[{self.dtype}]({self.base})'
|
||||
return f"TensorWrapper[{self.dtype}]({self.base})"
|
||||
|
||||
def element_size(self):
|
||||
return self.base.element_size()
|
||||
@@ -655,4 +788,4 @@ def reinterpret(tensor, dtype):
|
||||
# A new wrapper is needed around an unwrapped tensor.
|
||||
return TensorWrapper(tensor, dtype)
|
||||
else:
|
||||
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
|
||||
raise TypeError(f"Cannot reinterpret a {type(tensor)}.")
|
||||
|
||||
@@ -78,10 +78,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None):
|
||||
return torch.mean(torch.tensor(ret)).item()
|
||||
|
||||
|
||||
def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
|
||||
quantiles=None,
|
||||
fast_flush=True,
|
||||
return_mode="mean"):
|
||||
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"):
|
||||
assert return_mode in ["min", "max", "mean", "median"]
|
||||
import torch
|
||||
"""
|
||||
@@ -261,11 +258,12 @@ class Benchmark:
|
||||
|
||||
|
||||
class Mark:
|
||||
|
||||
def __init__(self, fn, benchmarks):
|
||||
self.fn = fn
|
||||
self.benchmarks = benchmarks
|
||||
|
||||
def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, **kwrags):
|
||||
def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False, **kwrags):
|
||||
import os
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -321,24 +319,36 @@ class Mark:
|
||||
if save_path:
|
||||
plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
|
||||
df = df[x_names + bench.line_names]
|
||||
if diff_col and df.shape[1] == 2:
|
||||
col0, col1 = df.columns.tolist()
|
||||
df['Diff'] = df[col1] - df[col0]
|
||||
|
||||
if print_data:
|
||||
print(bench.plot_name + ':')
|
||||
print(df)
|
||||
if save_path:
|
||||
df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)
|
||||
return df
|
||||
|
||||
def run(self, show_plots=False, print_data=False, save_path='', **kwargs):
|
||||
def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs):
|
||||
has_single_bench = isinstance(self.benchmarks, Benchmark)
|
||||
benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
|
||||
result_dfs = []
|
||||
if save_path:
|
||||
html = open(os.path.join(save_path, "results.html"), "w")
|
||||
html.write("<html><body>\n")
|
||||
for bench in benchmarks:
|
||||
self._run(bench, save_path, show_plots, print_data, **kwargs)
|
||||
result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
|
||||
if save_path:
|
||||
html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
|
||||
if save_path:
|
||||
html.write("</body></html>\n")
|
||||
if return_df:
|
||||
if has_single_bench:
|
||||
return result_dfs[0]
|
||||
else:
|
||||
return result_dfs
|
||||
return None
|
||||
|
||||
|
||||
def perf_report(benchmarks):
|
||||
@@ -393,12 +403,15 @@ def get_max_tensorcore_tflops(dtype, clock_rate, backend=None, device=None):
|
||||
tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
|
||||
return tflops
|
||||
|
||||
|
||||
# create decorator that wraps test function into
|
||||
# a cuda-memcheck system call
|
||||
|
||||
|
||||
def cuda_memcheck(**target_kwargs):
|
||||
|
||||
def decorator(test_fn):
|
||||
|
||||
@functools.wraps(test_fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
import psutil
|
||||
@@ -416,7 +429,9 @@ def cuda_memcheck(**target_kwargs):
|
||||
assert "ERROR SUMMARY: 0 errors" in str(out.stdout)
|
||||
else:
|
||||
test_fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@@ -424,22 +439,18 @@ def cuda_memcheck(**target_kwargs):
|
||||
def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
|
||||
try:
|
||||
subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"])
|
||||
subprocess.check_output(
|
||||
[
|
||||
"nvidia-smi",
|
||||
"-i",
|
||||
"0",
|
||||
f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
|
||||
]
|
||||
)
|
||||
subprocess.check_output(
|
||||
[
|
||||
"nvidia-smi",
|
||||
"-i",
|
||||
"0",
|
||||
f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
|
||||
]
|
||||
)
|
||||
subprocess.check_output([
|
||||
"nvidia-smi",
|
||||
"-i",
|
||||
"0",
|
||||
f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
|
||||
])
|
||||
subprocess.check_output([
|
||||
"nvidia-smi",
|
||||
"-i",
|
||||
"0",
|
||||
f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
|
||||
])
|
||||
cur_sm_clock = nvsmi(["clocks.current.sm"])[0]
|
||||
cur_mem_clock = nvsmi(["clocks.current.memory"])[0]
|
||||
assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz"
|
||||
|
||||
19
python/triton/third_party/hip/hip_backend.py
vendored
19
python/triton/third_party/hip/hip_backend.py
vendored
@@ -8,8 +8,8 @@ from typing import Any, Tuple
|
||||
|
||||
|
||||
from triton.common import _build
|
||||
from triton.common.backend import BaseBackend, register_backend
|
||||
from triton.compiler.make_launcher import get_cache_manager, version_key, make_so_cache_key
|
||||
from triton.common.backend import BaseBackend, register_backend, compute_core_version_key
|
||||
from triton.compiler.make_launcher import get_cache_manager, make_so_cache_key
|
||||
from triton.compiler.utils import generate_cu_signature
|
||||
from triton.runtime import jit
|
||||
from triton.runtime.driver import HIPDriver
|
||||
@@ -25,7 +25,7 @@ else:
|
||||
|
||||
def make_stub(name, signature, constants, ids, **kwargs):
|
||||
# name of files that are cached
|
||||
so_cache_key = make_so_cache_key(version_key(), signature, constants, ids, **kwargs)
|
||||
so_cache_key = make_so_cache_key(compute_core_version_key(), signature, constants, ids, **kwargs)
|
||||
so_cache_manager = get_cache_manager(so_cache_key)
|
||||
so_name = f"{name}.so"
|
||||
# retrieve stub from cache if it exists
|
||||
@@ -414,11 +414,21 @@ def llir_to_amdgcn_and_hsaco(mod: Any, gfx_arch: str, gfx_triple: str, gfx_featu
|
||||
|
||||
|
||||
class HIPBackend(BaseBackend):
|
||||
_cached_rocm_version_key = None
|
||||
|
||||
def __init__(self, device_type: str) -> None:
|
||||
super(HIPBackend, self).__init__(device_type)
|
||||
self.driver = HIPDriver()
|
||||
self.stub_so_path = ""
|
||||
|
||||
def get_version_key(self):
|
||||
if self._cached_rocm_version_key is None:
|
||||
key = compute_core_version_key()
|
||||
### TODO: Append ROCM version here if needed
|
||||
|
||||
self._cached_rocm_version_key = key
|
||||
return self._cached_rocm_version_key
|
||||
|
||||
def is_standalone(self):
|
||||
return not HIP_BACKEND_MODE
|
||||
|
||||
@@ -500,7 +510,6 @@ class HIPBackend(BaseBackend):
|
||||
return arch
|
||||
|
||||
def make_launcher_stub(self, name, signature, constants, ids):
|
||||
# print("HIPBackend.make_launcher_stub")
|
||||
self.stub_so_path = make_stub(name, signature, constants, ids)
|
||||
return self.stub_so_path
|
||||
|
||||
@@ -517,4 +526,4 @@ class HIPBackend(BaseBackend):
|
||||
return _triton.get_num_warps(module)
|
||||
|
||||
def get_matrix_core_version(self):
|
||||
return gpu_matrix_core_version()
|
||||
return gpu_matrix_core_version()
|
||||
|
||||
@@ -141,8 +141,7 @@ class ExternLibrary(ABC):
|
||||
f.write(file_str)
|
||||
f.close()
|
||||
if self._format:
|
||||
subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file],
|
||||
stdout=subprocess.PIPE).communicate()
|
||||
subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], stdout=subprocess.PIPE).communicate()
|
||||
subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate()
|
||||
|
||||
|
||||
@@ -208,56 +207,36 @@ class Libdevice(ExternLibrary):
|
||||
|
||||
# Group functions together by renaming.
|
||||
renaming = {
|
||||
'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh',
|
||||
'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': 'add_rn',
|
||||
'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru',
|
||||
'dadd_rz': 'add_rz', 'fadd_rz': 'add_rz', 'asinf': 'asin',
|
||||
'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2',
|
||||
'atanhf': 'atanh', 'brevll': 'brev', 'cbrtf': 'cbrt',
|
||||
'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign',
|
||||
'cosf': 'cos', 'coshf': 'cosh', 'cospif': 'cospi',
|
||||
'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
|
||||
'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn',
|
||||
'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', 'ddiv_ru': 'div_ru',
|
||||
'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf',
|
||||
'erfcf': 'erfc', 'erfcinvf': 'erfcinv', 'erfcxf': 'erfcx',
|
||||
'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10',
|
||||
'exp2f': 'exp2', 'expm1f': 'expm1', 'fabsf': 'abs',
|
||||
'fabs': 'abs', 'fast_fdividef': 'fast_dividef',
|
||||
'fdimf': 'fdim', 'ffsll': 'ffs', 'floorf': 'floor',
|
||||
'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn',
|
||||
'fmaf_ru': 'fma_ru', 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod',
|
||||
'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb',
|
||||
'isinff': 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan',
|
||||
'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
|
||||
'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint',
|
||||
'llroundf': 'llround', 'logf': 'log', 'log10f': 'log10',
|
||||
'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb',
|
||||
'umax': 'max', 'llmax': 'max', 'ullmax': 'max', 'fmaxf': 'max',
|
||||
'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min',
|
||||
'fminf': 'min', 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd',
|
||||
'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', 'dmul_ru': 'mul_ru',
|
||||
'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz',
|
||||
'umul24': 'mul24', 'umulhi': 'mulhi', 'mul64hi': 'mulhi',
|
||||
'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': 'nextafter',
|
||||
'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf',
|
||||
'normcdfinvf': 'normcdfinv', 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow',
|
||||
'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', 'drcp_rd': 'rcp_rd',
|
||||
'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru',
|
||||
'drcp_ru': 'rcp_ru', 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz',
|
||||
'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
|
||||
'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d',
|
||||
'roundf': 'round', 'rsqrtf': 'rsqrt', 'frsqrt_rn': 'rsqrt_rn',
|
||||
'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit',
|
||||
'signbitd': 'signbit', 'sinf': 'sin', 'sinhf': 'sinh',
|
||||
'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd',
|
||||
'dsqrt_rd': 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn',
|
||||
'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', 'fsqrt_rz': 'sqrt_rz',
|
||||
'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd',
|
||||
'fsub_rn': 'sub_rn', 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru',
|
||||
'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
|
||||
'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc',
|
||||
'y0f': 'y0', 'y1f': 'y1', 'ynf': 'yn'
|
||||
'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn':
|
||||
'add_rn', 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', 'dadd_rz': 'add_rz', 'fadd_rz':
|
||||
'add_rz', 'asinf': 'asin', 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', 'atanhf': 'atanh',
|
||||
'brevll': 'brev', 'cbrtf': 'cbrt', 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', 'cosf': 'cos',
|
||||
'coshf': 'cosh', 'cospif': 'cospi', 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
|
||||
'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru',
|
||||
'ddiv_ru': 'div_ru', 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', 'erfcf': 'erfc', 'erfcinvf':
|
||||
'erfcinv', 'erfcxf': 'erfcx', 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', 'exp2f': 'exp2',
|
||||
'expm1f': 'expm1', 'fabsf': 'abs', 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', 'fdimf': 'fdim', 'ffsll':
|
||||
'ffs', 'floorf': 'floor', 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', 'fmaf_ru': 'fma_ru',
|
||||
'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', 'isinff':
|
||||
'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
|
||||
'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', 'llroundf': 'llround', 'logf': 'log', 'log10f':
|
||||
'log10', 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', 'umax': 'max', 'llmax': 'max', 'ullmax':
|
||||
'max', 'fmaxf': 'max', 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', 'fminf': 'min',
|
||||
'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn',
|
||||
'dmul_ru': 'mul_ru', 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', 'umul24': 'mul24',
|
||||
'umulhi': 'mulhi', 'mul64hi': 'mulhi', 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf':
|
||||
'nextafter', 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', 'normcdfinvf': 'normcdfinv',
|
||||
'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd',
|
||||
'drcp_rd': 'rcp_rd', 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', 'drcp_ru': 'rcp_ru',
|
||||
'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
|
||||
'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', 'roundf': 'round', 'rsqrtf': 'rsqrt',
|
||||
'frsqrt_rn': 'rsqrt_rn', 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', 'signbitd': 'signbit',
|
||||
'sinf': 'sin', 'sinhf': 'sinh', 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', 'dsqrt_rd':
|
||||
'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru',
|
||||
'fsqrt_rz': 'sqrt_rz', 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', 'fsub_rn': 'sub_rn',
|
||||
'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
|
||||
'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', 'y0f': 'y0', 'y1f': 'y1', 'ynf':
|
||||
'yn'
|
||||
}
|
||||
|
||||
for symbol in self._symbols.values():
|
||||
@@ -347,8 +326,7 @@ class LLVMDisassembler:
|
||||
self._ll_file = "/tmp/extern_lib.ll"
|
||||
|
||||
def disasm(self, lib_path: str) -> None:
    """Run the executable at ``self._path`` on ``lib_path``.

    The output location ``self.ll_file`` is passed with ``-o``. The call
    blocks until the child process exits; its stdout is captured and
    discarded, and its exit status is not inspected.
    """
    command = [self._path, lib_path, "-o", self.ll_file]
    process = subprocess.Popen(command, stdout=subprocess.PIPE)
    process.communicate()
|
||||
|
||||
@property
|
||||
def ll_file(self) -> str:
|
||||
|
||||
@@ -40,10 +40,13 @@ if __name__ == "__main__":
|
||||
|
||||
# command-line arguments
|
||||
parser = ArgumentParser(description=desc)
|
||||
parser.add_argument("path", help="Path to Python source containing desired kernel in its scope. File will be executed.")
|
||||
parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", required=True)
|
||||
parser.add_argument("path",
|
||||
help="Path to Python source containing desired kernel in its scope. File will be executed.")
|
||||
parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile",
|
||||
required=True)
|
||||
parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel")
|
||||
parser.add_argument("--num-stages", "-ns", type=int, default=3, help="Number of stages (meta-parameter of the kernel)")
|
||||
parser.add_argument("--num-stages", "-ns", type=int, default=3,
|
||||
help="Number of stages (meta-parameter of the kernel)")
|
||||
parser.add_argument("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel")
|
||||
parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename")
|
||||
parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True)
|
||||
@@ -104,7 +107,8 @@ if __name__ == "__main__":
|
||||
config = triton.compiler.instance_descriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1)
|
||||
for i in equal_to_1:
|
||||
constexprs.update({i: 1})
|
||||
ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config], num_warps=args.num_warps, num_stages=args.num_stages)
|
||||
ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config],
|
||||
num_warps=args.num_warps, num_stages=args.num_stages)
|
||||
arg_names = []
|
||||
arg_types = []
|
||||
for i in signature.keys():
|
||||
|
||||
@@ -27,6 +27,7 @@ class KernelLinkerMeta:
|
||||
|
||||
|
||||
class HeaderParser:
|
||||
|
||||
def __init__(self) -> None:
|
||||
import re
|
||||
|
||||
@@ -42,7 +43,6 @@ class HeaderParser:
|
||||
self.kernels = defaultdict(list)
|
||||
|
||||
def extract_linker_meta(self, header: str):
|
||||
|
||||
for ln in header.splitlines():
|
||||
if ln.startswith("//"):
|
||||
m = self.linker_directives.match(ln)
|
||||
@@ -76,7 +76,7 @@ class HeaderParser:
|
||||
m = self.c_sig.findall(c_sig)
|
||||
if len(m):
|
||||
tys, args = [], []
|
||||
for (ty, arg_name) in m:
|
||||
for ty, arg_name in m:
|
||||
tys.append(ty)
|
||||
args.append(arg_name)
|
||||
return tys, args
|
||||
@@ -84,7 +84,7 @@ class HeaderParser:
|
||||
raise LinkerError(f"{c_sig} is not a valid argument signature")
|
||||
|
||||
def _match_suffix(self, suffix: str, c_sig: str):
|
||||
args = c_sig.split(',')
|
||||
args = c_sig.split(",")
|
||||
s2i = {"c": 1, "d": 16}
|
||||
num_specs = 0
|
||||
sizes = []
|
||||
@@ -110,7 +110,7 @@ class HeaderParser:
|
||||
if name in self.kernels:
|
||||
last: KernelLinkerMeta = self.kernels[name][-1]
|
||||
|
||||
for (cur, new_) in zip(last.arg_ctypes, ker.arg_ctypes):
|
||||
for cur, new_ in zip(last.arg_ctypes, ker.arg_ctypes):
|
||||
if cur != new_:
|
||||
raise LinkerError(
|
||||
f"Mismatched signature for kernel {name}: \n\texisting sig is: {','.join(last.arg_ctypes)}\n\tcurrent is: {','.join(ker.arg_ctypes)}"
|
||||
@@ -152,7 +152,7 @@ void unload_{meta.orig_kernel_name}();
|
||||
# generate dispatcher function for kernels with different meta-parameter and constant values
|
||||
def make_default_algo_kernel(meta: KernelLinkerMeta) -> str:
    """Emit C source for the ``<kernel>_default`` wrapper.

    The generated function takes a ``CUstream`` plus the full argument
    signature produced by ``gen_signature_with_full_args(meta)`` and
    forwards every argument to ``<kernel>`` with a trailing ``0``
    (presumably the algorithm/variant selector -- confirm against the
    dispatcher's signature).
    """
    kernel_name = meta.orig_kernel_name
    full_signature = gen_signature_with_full_args(meta)
    forwarded_args = ', '.join(meta.arg_names)
    pieces = [
        f"CUresult {kernel_name}_default(CUstream stream, {full_signature}){{\n",
        f" return {kernel_name}(stream, {forwarded_args}, 0);\n",
        "}\n",
    ]
    return "".join(pieces)
|
||||
|
||||
@@ -164,12 +164,22 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -
|
||||
src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n"
|
||||
src += "\n"
|
||||
|
||||
src += f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{"
|
||||
src += (f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{")
|
||||
src += "\n"
|
||||
for meta in sorted(metas, key=lambda m: -m.num_specs):
|
||||
cond_fn = lambda val, hint: f"({val} % {hint} == 0)" if hint == 16 else f"({val} == {hint})" if hint == 1 else None
|
||||
conds = " && ".join([cond_fn(val, hint) for val, hint in zip(meta.arg_names, meta.sizes) if hint is not None])
|
||||
src += f" if ({conds})\n"
|
||||
cond_fn = ( #
|
||||
lambda val, hint: f"({val} % {hint} == 0)" #
|
||||
if hint == 16 #
|
||||
else f"({val} == {hint})" #
|
||||
if hint == 1 #
|
||||
else None)
|
||||
conds = " && ".join([ #
|
||||
cond_fn(val, hint) #
|
||||
for val, hint in zip(meta.arg_names, meta.sizes) #
|
||||
if hint is not None
|
||||
])
|
||||
src += (f" if ({conds})\n" if any(meta.sizes) else "if (1)\n"
|
||||
) # Edge case where no specializations hence no dispatching required
|
||||
arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1]
|
||||
src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n"
|
||||
src += "\n"
|
||||
@@ -183,7 +193,7 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -
|
||||
src += f"void {mode}_{name}() {{"
|
||||
src += "\n"
|
||||
for meta in sorted(metas, key=lambda m: -m.num_specs):
|
||||
src += f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n"
|
||||
src += (f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n")
|
||||
src += "}\n"
|
||||
return src
|
||||
|
||||
@@ -252,7 +262,12 @@ if __name__ == "__main__":
|
||||
help="Paths to header files to link. Must include linker directive annotations (autogenerated by ttc)",
|
||||
)
|
||||
parser.add_argument("--out", "-o", type=Path, help="Out filename")
|
||||
parser.add_argument("--prefix", type=str, default="", help="String to prefix kernel dispatcher names")
|
||||
parser.add_argument(
|
||||
"--prefix",
|
||||
type=str,
|
||||
default="",
|
||||
help="String to prefix kernel dispatcher names",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# metadata
|
||||
|
||||
Reference in New Issue
Block a user