mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Reformat Python code with yapf. (#2589)
I've add an option to yapf to do what we want for long lines, see https://github.com/google/yapf/pull/1177. We can now have a real Python formatter, yay! To make this PR, I ran my modified yapf over the repository, then looked over the full diff. Where yapf was mangling the param list of long function decls/calls (mostly kernels), I manually added `#` to put linebreaks where we want. I fixed up other formatting too -- mostly adding or removing a trailing comma from lists. Overall, trailing `#` was sufficient to get formatting similar to our current code. I didn't have to disable yapf anywhere. --------- Co-authored-by: Phil Tillet <phil@openai.com>
This commit is contained in:
@@ -45,12 +45,12 @@ __all__ = [
|
||||
"tools",
|
||||
]
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
# misc. utilities that don't fit well
|
||||
# into any specific module
|
||||
# -------------------------------------
|
||||
|
||||
|
||||
def cdiv(x: int, y: int):
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import importlib
|
||||
@@ -16,6 +15,7 @@ TRITON_VERSION = "2.1.0"
|
||||
|
||||
|
||||
class BaseBackend:
|
||||
|
||||
def __init__(self, device_type: str) -> None:
|
||||
self.device_type = device_type
|
||||
|
||||
@@ -154,7 +154,7 @@ def compute_core_version_key():
|
||||
libtriton_hash = hashlib.sha1()
|
||||
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
|
||||
while True:
|
||||
chunk = f.read(1024 ** 2)
|
||||
chunk = f.read(1024**2)
|
||||
if not chunk:
|
||||
break
|
||||
libtriton_hash.update(chunk)
|
||||
|
||||
@@ -86,9 +86,15 @@ def _build(name, src, srcdir):
|
||||
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
|
||||
|
||||
if is_hip():
|
||||
ret = subprocess.check_call([cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{hip_lib_dir}", "-lamdhip64", "-o", so])
|
||||
ret = subprocess.check_call([
|
||||
cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC",
|
||||
f"-L{hip_lib_dir}", "-lamdhip64", "-o", so
|
||||
])
|
||||
else:
|
||||
cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
|
||||
cc_cmd = [
|
||||
cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda",
|
||||
"-o", so
|
||||
]
|
||||
cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
|
||||
ret = subprocess.check_call(cc_cmd)
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from .compiler import (CompiledKernel, compile, get_arch_default_num_stages,
|
||||
get_arch_default_num_warps, instance_descriptor)
|
||||
from .compiler import (CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps,
|
||||
instance_descriptor)
|
||||
from .errors import CompilationError
|
||||
|
||||
__all__ = ["compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", "get_arch_default_num_stages"]
|
||||
__all__ = [
|
||||
"compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps",
|
||||
"get_arch_default_num_stages"
|
||||
]
|
||||
|
||||
@@ -10,8 +10,7 @@ from .._C.libtriton.triton import ir
|
||||
from ..language import constexpr, tensor
|
||||
# ideally we wouldn't need any runtime component
|
||||
from ..runtime import JITFunction
|
||||
from .errors import (CompilationError, CompileTimeAssertionFailure,
|
||||
UnsupportedLanguageConstruct)
|
||||
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
|
||||
|
||||
|
||||
def mangle_ty(ty):
|
||||
@@ -68,7 +67,10 @@ def _check_fn_args(node, fn, args):
|
||||
if fn.noinline:
|
||||
for idx, arg in enumerate(args):
|
||||
if not _is_constexpr(arg) and not _is_triton_scalar(arg):
|
||||
raise UnsupportedLanguageConstruct(fn.src, node, f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}')
|
||||
raise UnsupportedLanguageConstruct(
|
||||
fn.src, node,
|
||||
f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}'
|
||||
)
|
||||
|
||||
|
||||
def _get_fn_file_line(fn):
|
||||
@@ -89,6 +91,7 @@ _condition_types = {bool, int, type(None)} # Python types accepted for conditio
|
||||
|
||||
|
||||
class enter_sub_region:
|
||||
|
||||
def __init__(self, generator):
|
||||
self.generator = generator
|
||||
|
||||
@@ -109,6 +112,7 @@ class enter_sub_region:
|
||||
|
||||
# Check if the given syntax node has an "early" return
|
||||
class ContainsReturnChecker(ast.NodeVisitor):
|
||||
|
||||
def __init__(self, gscope):
|
||||
self.gscope = gscope
|
||||
|
||||
@@ -199,9 +203,10 @@ class ContainsReturnChecker(ast.NodeVisitor):
|
||||
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, function_name, target,
|
||||
module=None, is_kernel=False, function_types: Optional[Dict] = None,
|
||||
debug=False, noinline=False, file_name: Optional[str] = None, begin_line=0):
|
||||
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, function_name, target, module=None,
|
||||
is_kernel=False, function_types: Optional[Dict] = None, debug=False, noinline=False,
|
||||
file_name: Optional[str] = None, begin_line=0):
|
||||
self.context = context
|
||||
self.builder = ir.builder(context)
|
||||
self.file_name = file_name
|
||||
@@ -237,8 +242,10 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
))
|
||||
|
||||
def _define_name_lookup(self):
|
||||
|
||||
def local_lookup(name: str, absent):
|
||||
value = self.lscope.get(name, absent) # this needs to be re-fetched from `self` every time, because it gets switched occasionally
|
||||
# this needs to be re-fetched from `self` every time, because it gets switched occasionally
|
||||
value = self.lscope.get(name, absent)
|
||||
if value is not absent and name not in self.local_defs:
|
||||
self.global_uses[name] = value
|
||||
return value
|
||||
@@ -255,8 +262,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
return name_lookup
|
||||
|
||||
def set_value(self, name: str,
|
||||
value: Union[tensor, constexpr]) -> None:
|
||||
def set_value(self, name: str, value: Union[tensor, constexpr]) -> None:
|
||||
''' This function:
|
||||
called by visit_Assign() & visit_FunctionDef() to store left value (lvalue)
|
||||
1. record local defined name (FIXME: should consider control flow)
|
||||
@@ -338,7 +344,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.visit(init_node)
|
||||
# initialize function
|
||||
visibility = "public" if self.is_kernel else "private"
|
||||
self.fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline)
|
||||
self.fn = self.builder.get_or_insert_function(self.module, self.function_name,
|
||||
self.prototype.to_ir(self.builder), visibility, self.noinline)
|
||||
self.module.push_back(self.fn)
|
||||
entry = self.fn.add_entry_block()
|
||||
arg_values = []
|
||||
@@ -469,12 +476,23 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
rhs = self.visit(node.right)
|
||||
method_name = self._method_name_for_bin_op.get(type(node.op))
|
||||
if method_name is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
return self._apply_binary_method(method_name, lhs, rhs)
|
||||
|
||||
_method_name_for_bin_op: Dict[Type[ast.operator], str] = {
|
||||
ast.Add: '__add__', ast.Sub: '__sub__', ast.Mult: '__mul__', ast.Div: '__truediv__',
|
||||
ast.FloorDiv: '__floordiv__', ast.Mod: '__mod__', ast.Pow: '__pow__',
|
||||
ast.LShift: '__lshift__', ast.RShift: '__rshift__', ast.BitAnd: '__and__', ast.BitOr: '__or__', ast.BitXor: '__xor__',
|
||||
ast.Add: '__add__',
|
||||
ast.Sub: '__sub__',
|
||||
ast.Mult: '__mul__',
|
||||
ast.Div: '__truediv__',
|
||||
ast.FloorDiv: '__floordiv__',
|
||||
ast.Mod: '__mod__',
|
||||
ast.Pow: '__pow__',
|
||||
ast.LShift: '__lshift__',
|
||||
ast.RShift: '__rshift__',
|
||||
ast.BitAnd: '__and__',
|
||||
ast.BitOr: '__or__',
|
||||
ast.BitXor: '__xor__',
|
||||
}
|
||||
|
||||
def visit_then_else_blocks(self, node, liveins, then_block, else_block):
|
||||
@@ -508,7 +526,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if name in then_defs or name in else_defs:
|
||||
names.append(name)
|
||||
ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type)
|
||||
ir_ret_types.append(then_defs[name].handle.get_type() if name in then_defs else else_defs[name].handle.get_type())
|
||||
ir_ret_types.append(then_defs[name].handle.get_type() if name in
|
||||
then_defs else else_defs[name].handle.get_type())
|
||||
# variable defined in then but not in else
|
||||
if name in then_defs and name not in else_defs:
|
||||
else_defs[name] = liveins[name]
|
||||
@@ -602,8 +621,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
contains_return = ContainsReturnChecker(self.gscope).visit(node)
|
||||
if self.scf_stack and contains_return:
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node,
|
||||
"Cannot have `return` statements inside `while` or `for` statements in triton "
|
||||
None, node, "Cannot have `return` statements inside `while` or `for` statements in triton "
|
||||
"(note that this also applies to `return` statements that are inside functions "
|
||||
"transitively called from within `while`/`for` statements)")
|
||||
elif self.scf_stack or not contains_return:
|
||||
@@ -612,10 +630,13 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.visit_if_top_level(cond, node)
|
||||
else:
|
||||
cond = _unwrap_if_constexpr(cond)
|
||||
if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks
|
||||
# not isinstance - we insist the real thing, no subclasses and no ducks
|
||||
if type(cond) not in _condition_types:
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
|
||||
', '.join(_.__name__ for _ in _condition_types), type(cond).__name__))
|
||||
None, node,
|
||||
"`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
|
||||
', '.join(_.__name__ for _ in _condition_types),
|
||||
type(cond).__name__))
|
||||
if cond:
|
||||
self.visit_compound_statement(node.body)
|
||||
else:
|
||||
@@ -662,10 +683,14 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None
|
||||
else:
|
||||
cond = _unwrap_if_constexpr(cond)
|
||||
if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks
|
||||
|
||||
# not isinstance - we insist the real thing, no subclasses and no ducks
|
||||
if type(cond) not in _condition_types:
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
|
||||
', '.join(_.__name__ for _ in _condition_types), type(cond).__name__))
|
||||
None, node,
|
||||
"`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
|
||||
', '.join(_.__name__ for _ in _condition_types),
|
||||
type(cond).__name__))
|
||||
if cond:
|
||||
return self.visit(node.body)
|
||||
else:
|
||||
@@ -687,8 +712,10 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
return constexpr(lhs_value is not rhs_value)
|
||||
method_name = self._method_name_for_comp_op.get(type(node.ops[0]))
|
||||
if method_name is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
|
||||
return self._apply_binary_method(method_name, lhs, rhs)
|
||||
|
||||
_method_name_for_comp_op: Dict[Type[ast.cmpop], str] = {
|
||||
ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__'
|
||||
}
|
||||
@@ -697,11 +724,15 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
op = self.visit(node.operand)
|
||||
fn = self._method_name_for_unary_op.get(type(node.op))
|
||||
if fn is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST unary operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "AST unary operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
if _is_triton_tensor(op):
|
||||
return getattr(op, fn)(_builder=self.builder)
|
||||
return getattr(op, fn)()
|
||||
_method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'}
|
||||
|
||||
_method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {
|
||||
ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'
|
||||
}
|
||||
|
||||
def visit_While(self, node):
|
||||
with enter_sub_region(self) as sr:
|
||||
@@ -796,9 +827,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
iter_args = [self.visit(arg) for arg in node.iter.args]
|
||||
if IteratorClass == language.static_range:
|
||||
iterator = IteratorClass(*iter_args)
|
||||
static_range = range(iterator.start.value,
|
||||
iterator.end.value,
|
||||
iterator.step.value)
|
||||
static_range = range(iterator.start.value, iterator.end.value, iterator.step.value)
|
||||
for i in static_range:
|
||||
self.lscope[node.target.id] = constexpr(i)
|
||||
self.visit_compound_statement(node.body)
|
||||
@@ -935,8 +964,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def call_JitFunction(self, fn: JITFunction, args, kwargs):
|
||||
args = inspect.getcallargs(fn.fn, *args, **kwargs)
|
||||
args = [args[name] for name in fn.arg_names]
|
||||
args = [arg if _is_triton_tensor(arg)
|
||||
else constexpr(arg) for arg in args]
|
||||
args = [arg if _is_triton_tensor(arg) else constexpr(arg) for arg in args]
|
||||
# generate function def
|
||||
attributes = dict()
|
||||
constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
|
||||
@@ -954,8 +982,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
debug = self.debug if fn.debug is None else fn.debug
|
||||
file_name, begin_line = _get_fn_file_line(fn)
|
||||
generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module,
|
||||
function_name=fn_name, function_types=self.function_ret_types, debug=debug, noinline=fn.noinline,
|
||||
file_name=file_name, begin_line=begin_line, target=self.builder.target)
|
||||
function_name=fn_name, function_types=self.function_ret_types, debug=debug,
|
||||
noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
|
||||
target=self.builder.target)
|
||||
generator.visit(fn.parse())
|
||||
callee_ret_type = generator.last_ret_type
|
||||
self.function_ret_types[fn_name] = callee_ret_type
|
||||
@@ -983,7 +1012,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
kws = dict(self.visit(keyword) for keyword in node.keywords)
|
||||
args = [self.visit(arg) for arg in node.args]
|
||||
if fn is language.core.device_assert: # TODO: this should not be so hardcoded
|
||||
if fn is language.core.device_assert: # TODO: this should not be so hardcoded
|
||||
if not self.debug:
|
||||
return
|
||||
if isinstance(fn, JITFunction):
|
||||
@@ -1004,16 +1033,21 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def visit_BoolOp(self, node: ast.BoolOp):
|
||||
if len(node.values) != 2:
|
||||
raise UnsupportedLanguageConstruct(None, node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node,
|
||||
"chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
|
||||
lhs = self.visit(node.values[0])
|
||||
rhs = self.visit(node.values[1])
|
||||
method_name = self._method_name_for_bool_op.get(type(node.op))
|
||||
if method_name is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
return self._apply_binary_method(method_name, lhs, rhs)
|
||||
|
||||
_method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
|
||||
def visit_NameConstant(self, node):
|
||||
return constexpr(node.value)
|
||||
|
||||
@@ -1046,7 +1080,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
evaluated = self.visit(value.value)
|
||||
if not _is_constexpr(evaluated):
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + str(type(evaluated)))
|
||||
None, node,
|
||||
"Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type "
|
||||
+ str(type(evaluated)))
|
||||
values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value)
|
||||
else:
|
||||
raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value)))
|
||||
@@ -1088,7 +1124,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
passed = _unwrap_if_constexpr(self.visit(node.args[0]))
|
||||
if not isinstance(passed, bool):
|
||||
raise NotImplementedError("Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values")
|
||||
raise NotImplementedError(
|
||||
"Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values"
|
||||
)
|
||||
if not passed:
|
||||
if arg_count == 1:
|
||||
message = ""
|
||||
@@ -1175,10 +1213,9 @@ def ast_to_ttir(fn, signature, specialization, constants, debug, target):
|
||||
file_name, begin_line = _get_fn_file_line(fn)
|
||||
|
||||
prototype = language.function_type([], arg_types)
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants,
|
||||
function_name=function_name, attributes=new_attrs,
|
||||
is_kernel=True, debug=debug, file_name=file_name, begin_line=begin_line,
|
||||
target=target)
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name,
|
||||
attributes=new_attrs, is_kernel=True, debug=debug, file_name=file_name,
|
||||
begin_line=begin_line, target=target)
|
||||
try:
|
||||
generator.visit(fn.parse())
|
||||
except CompilationError as e:
|
||||
|
||||
@@ -11,10 +11,8 @@ from typing import Any
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs,
|
||||
compile_ptx_to_cubin, get_env_vars, get_num_warps,
|
||||
get_shared_memory_size, ir, runtime,
|
||||
translate_llvmir_to_ptx,
|
||||
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs, compile_ptx_to_cubin, get_env_vars,
|
||||
get_num_warps, get_shared_memory_size, ir, runtime, translate_llvmir_to_ptx,
|
||||
translate_triton_gpu_to_llvmir)
|
||||
from ..common.backend import get_backend, get_cuda_version_key, path_to_ptxas
|
||||
from ..common.build import is_hip
|
||||
@@ -23,13 +21,11 @@ from ..common.build import is_hip
|
||||
from ..runtime.autotuner import OutOfResources
|
||||
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
|
||||
from ..runtime.driver import driver
|
||||
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device,
|
||||
get_device_capability)
|
||||
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device, get_device_capability)
|
||||
from ..tools.disasm import get_sass
|
||||
from .code_generator import ast_to_ttir
|
||||
from .make_launcher import make_stub
|
||||
from .utils import (InfoFromBackendForTensorMap, TensorMapManager,
|
||||
get_ids_of_tensormaps, parse_tma_info)
|
||||
from .utils import (InfoFromBackendForTensorMap, TensorMapManager, get_ids_of_tensormaps, parse_tma_info)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -44,6 +40,7 @@ def _is_cuda(target):
|
||||
|
||||
|
||||
class LazyDict(dict):
|
||||
|
||||
def __getitem__(self, key):
|
||||
val = dict.__getitem__(self, key)
|
||||
if callable(val):
|
||||
@@ -94,8 +91,8 @@ def ttir_to_ttgir(mod, num_warps, num_ctas, target):
|
||||
return mod
|
||||
|
||||
|
||||
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target,
|
||||
cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue):
|
||||
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization,
|
||||
enable_persistent, optimize_epilogue):
|
||||
is_cuda = _is_cuda(target)
|
||||
if is_cuda:
|
||||
capability = target.capability
|
||||
@@ -173,6 +170,7 @@ def ttgir_to_llir(mod, extern_libs, target, tma_infos):
|
||||
|
||||
# PTX translation
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def ptx_get_version(cuda_version) -> int:
|
||||
'''
|
||||
@@ -253,7 +251,8 @@ def make_hash(fn, target, env_vars, device_backend, **kwargs):
|
||||
enable_persistent = kwargs.get("enable_persistent", False)
|
||||
debug = kwargs.get("debug", False)
|
||||
# Get unique key for the compiled code
|
||||
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
|
||||
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1),
|
||||
sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
|
||||
configs_key = [get_conf_key(conf) for conf in configs]
|
||||
env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())]
|
||||
key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
|
||||
@@ -299,12 +298,14 @@ else:
|
||||
|
||||
|
||||
def _get_jsonable_constants(constants):
|
||||
|
||||
def _is_jsonable(x):
|
||||
try:
|
||||
json.dumps(x)
|
||||
return True
|
||||
except (TypeError, OverflowError):
|
||||
return False
|
||||
|
||||
serialized_constants = {}
|
||||
for constant in constants:
|
||||
if _is_jsonable(constants[constant]):
|
||||
@@ -319,7 +320,9 @@ def parse_mlir_module(path, context):
|
||||
return module
|
||||
|
||||
|
||||
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], defaults=[set(), set(), set(), set()])
|
||||
instance_descriptor = namedtuple("instance_descriptor",
|
||||
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
|
||||
defaults=[set(), set(), set(), set()])
|
||||
|
||||
|
||||
def get_cuda_capability(capability):
|
||||
@@ -355,10 +358,8 @@ def get_arch_default_num_stages(device_type, capability=None):
|
||||
|
||||
def add_cuda_stages(target, extern_libs, stages):
|
||||
|
||||
stages["ptx"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: llir_to_ptx(src, target))
|
||||
stages["cubin"] = (lambda path: Path(path).read_bytes(),
|
||||
lambda src: ptx_to_cubin(src, target))
|
||||
stages["ptx"] = (lambda path: Path(path).read_text(), lambda src: llir_to_ptx(src, target))
|
||||
stages["cubin"] = (lambda path: Path(path).read_bytes(), lambda src: ptx_to_cubin(src, target))
|
||||
|
||||
|
||||
def compile(fn, **kwargs):
|
||||
@@ -401,7 +402,8 @@ def compile(fn, **kwargs):
|
||||
# build architecture descriptor
|
||||
if device_type == "cuda":
|
||||
_device_backend = get_backend(device_type)
|
||||
target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps, enable_fp_fusion=enable_fp_fusion)
|
||||
target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps,
|
||||
enable_fp_fusion=enable_fp_fusion)
|
||||
else:
|
||||
_device_backend = get_backend(device_type)
|
||||
assert _device_backend
|
||||
@@ -409,11 +411,12 @@ def compile(fn, **kwargs):
|
||||
# build compilation stages
|
||||
stages = dict()
|
||||
stages["ast"] = (lambda path: fn, None)
|
||||
stages["ttir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
|
||||
stages["ttir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttir(
|
||||
ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
|
||||
if is_cuda:
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttgir(
|
||||
ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info,
|
||||
enable_warp_specialization, enable_persistent, optimize_epilogue))
|
||||
stages["llir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, target, tma_infos))
|
||||
add_cuda_stages(target, extern_libs, stages)
|
||||
@@ -451,7 +454,8 @@ def compile(fn, **kwargs):
|
||||
if ir_name == 'ttgir':
|
||||
num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
|
||||
assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
|
||||
assert "num_warps" not in kwargs or int(num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile"
|
||||
assert "num_warps" not in kwargs or int(
|
||||
num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile"
|
||||
num_warps = int(num_warps_matches[0])
|
||||
param_tys = [convert_type_repr(ty) for ty in types]
|
||||
signature = {k: v for k, v in enumerate(param_tys)}
|
||||
@@ -461,8 +465,10 @@ def compile(fn, **kwargs):
|
||||
fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs))
|
||||
# managers used to dump and override IR for debugging
|
||||
enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
|
||||
fn_override_manager = get_override_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
|
||||
fn_dump_manager = get_dump_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
|
||||
fn_override_manager = get_override_manager(
|
||||
make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
|
||||
fn_dump_manager = get_dump_manager(
|
||||
make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
|
||||
|
||||
# determine name and extension type of provided function
|
||||
if isinstance(fn, JITFunction):
|
||||
@@ -475,9 +481,7 @@ def compile(fn, **kwargs):
|
||||
metadata_filename = f"{name}.json"
|
||||
|
||||
# The group is addressed by the metadata
|
||||
metadata_group = fn_cache_manager.get_group(
|
||||
metadata_filename
|
||||
) or {}
|
||||
metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
|
||||
|
||||
metadata_path = metadata_group.get(metadata_filename)
|
||||
|
||||
@@ -485,17 +489,18 @@ def compile(fn, **kwargs):
|
||||
with open(metadata_path) as f:
|
||||
metadata = json.load(f)
|
||||
if 'tensormaps_info' in metadata:
|
||||
metadata['tensormaps_info'] = [
|
||||
InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']]
|
||||
metadata['tensormaps_info'] = [InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']]
|
||||
else:
|
||||
metadata = {"num_warps": num_warps,
|
||||
"num_ctas": num_ctas,
|
||||
"num_stages": num_stages,
|
||||
"enable_warp_specialization": enable_warp_specialization,
|
||||
"enable_persistent": enable_persistent,
|
||||
"constants": _get_jsonable_constants(constants),
|
||||
"debug": debug,
|
||||
"target": target, }
|
||||
metadata = {
|
||||
"num_warps": num_warps,
|
||||
"num_ctas": num_ctas,
|
||||
"num_stages": num_stages,
|
||||
"enable_warp_specialization": enable_warp_specialization,
|
||||
"enable_persistent": enable_persistent,
|
||||
"constants": _get_jsonable_constants(constants),
|
||||
"debug": debug,
|
||||
"target": target,
|
||||
}
|
||||
metadata.update(get_env_vars())
|
||||
if ext == "ptx":
|
||||
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
|
||||
@@ -567,10 +572,7 @@ def compile(fn, **kwargs):
|
||||
|
||||
ids_of_folded_args = tuple([int(k) for k in configs[0].ids_of_folded_args]) if isinstance(fn, JITFunction) else ()
|
||||
if "clusterDims" not in metadata:
|
||||
metadata["clusterDims"] = [
|
||||
cluster_info.clusterDimX,
|
||||
cluster_info.clusterDimY,
|
||||
cluster_info.clusterDimZ]
|
||||
metadata["clusterDims"] = [cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ]
|
||||
|
||||
if len(tma_infos) > 0:
|
||||
metadata["tensormaps_info"] = parse_tma_info(tma_infos, ids_of_folded_args)
|
||||
@@ -584,7 +586,10 @@ def compile(fn, **kwargs):
|
||||
fn.tensormaps_info = metadata["tensormaps_info"]
|
||||
|
||||
ids_of_const_exprs = tuple(fn.constexprs) if isinstance(fn, JITFunction) else ()
|
||||
ids = {"ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs": ids_of_const_exprs}
|
||||
ids = {
|
||||
"ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs":
|
||||
ids_of_const_exprs
|
||||
}
|
||||
# cache manager
|
||||
if is_cuda:
|
||||
so_path = make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization)
|
||||
@@ -592,7 +597,8 @@ def compile(fn, **kwargs):
|
||||
so_path = _device_backend.make_launcher_stub(name, signature, constants, ids)
|
||||
# write-back metadata, if it didn't come from the cache
|
||||
if metadata_path is None:
|
||||
metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, binary=False)
|
||||
metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
|
||||
binary=False)
|
||||
fn_cache_manager.put_group(metadata_filename, metadata_group)
|
||||
|
||||
# return handle to compiled kernel
|
||||
@@ -640,10 +646,7 @@ class CompiledKernel:
|
||||
|
||||
if self.device_type in ["cuda"]:
|
||||
device = get_current_device()
|
||||
bin_path = {
|
||||
driver.HIP: "hsaco_path",
|
||||
driver.CUDA: "cubin"
|
||||
}[driver.backend]
|
||||
bin_path = {driver.HIP: "hsaco_path", driver.CUDA: "cubin"}[driver.backend]
|
||||
max_shared = driver.utils.get_device_properties(device)["max_shared_mem"]
|
||||
fn_load_binary = driver.utils.load_binary
|
||||
else:
|
||||
@@ -691,4 +694,5 @@ class CompiledKernel:
|
||||
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.clusterDims[0],
|
||||
self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function,
|
||||
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand)
|
||||
|
||||
return runner
|
||||
|
||||
@@ -40,6 +40,7 @@ def make_stub(name, signature, constants, ids, **kwargs):
|
||||
else:
|
||||
return cache_path
|
||||
|
||||
|
||||
# ----- source code generation --------
|
||||
|
||||
|
||||
@@ -100,7 +101,10 @@ def generate_launcher(constants, signature, ids):
|
||||
|
||||
# generate glue code
|
||||
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
|
||||
params = [i for i in signature.keys() if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)]
|
||||
params = [
|
||||
i for i in signature.keys()
|
||||
if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)
|
||||
]
|
||||
src = f"""
|
||||
#include \"cuda.h\"
|
||||
#include <stdbool.h>
|
||||
|
||||
@@ -158,19 +158,21 @@ class InfoFromBackendForTensorMap:
|
||||
|
||||
# dtype:cuda.CUtensorMapDataType | int
|
||||
def bytes_from_type(self, dtype):
|
||||
return {driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT8"]: 1,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ"]: 4}[dtype]
|
||||
return {
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT8"]: 1,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ"]: 4
|
||||
}[dtype]
|
||||
|
||||
def getTensorMapDataType(self):
|
||||
return self.tensorDataType
|
||||
@@ -259,22 +261,29 @@ class InfoFromBackendForTensorMap:
|
||||
self.getInterleave(),
|
||||
self.getSwizzle(),
|
||||
self.getL2Promotion(),
|
||||
self.getOobFill()
|
||||
self.getOobFill(),
|
||||
)
|
||||
|
||||
# make hashable to use as partial key in cache
|
||||
def __hash__(self):
|
||||
return hash((self.ids_of_folded_args, self.globalAddressArgIdx, tuple(self.globalDimsArgIdx), tuple(self.globalStridesArgIdx), self.tensorDataType,
|
||||
self.tensorRank, tuple(self.boxDims), tuple(self.elementStrides), self.interleave, self.swizzle, self.l2Promotion, self.oobFill))
|
||||
return hash((self.ids_of_folded_args, self.globalAddressArgIdx, tuple(self.globalDimsArgIdx),
|
||||
tuple(self.globalStridesArgIdx), self.tensorDataType, self.tensorRank, tuple(self.boxDims),
|
||||
tuple(self.elementStrides), self.interleave, self.swizzle, self.l2Promotion, self.oobFill))
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
return (self.ids_of_folded_args, self.globalAddressArgIdx, self.globalDimsArgIdx, self.globalStridesArgIdx, self.tensorDataType, self.tensorRank, self.boxDims, self.elementStrides, self.interleave, self.swizzle, self.l2Promotion, self.oobFill) == (
|
||||
other.ids_of_folded_args, other.globalAddressArgIdx, other.globalDimsArgIdx, other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims, other.elementStrides, other.interleave, other.swizzle, other.l2Promotion, other.oobFill)
|
||||
return (self.ids_of_folded_args, self.globalAddressArgIdx, self.globalDimsArgIdx, self.globalStridesArgIdx,
|
||||
self.tensorDataType, self.tensorRank, self.boxDims, self.elementStrides, self.interleave, self.swizzle,
|
||||
self.l2Promotion,
|
||||
self.oobFill) == (other.ids_of_folded_args, other.globalAddressArgIdx, other.globalDimsArgIdx,
|
||||
other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims,
|
||||
other.elementStrides, other.interleave, other.swizzle, other.l2Promotion,
|
||||
other.oobFill)
|
||||
|
||||
|
||||
class TensorMapManager:
|
||||
|
||||
def __init__(self):
|
||||
self.tensormaps_device = {}
|
||||
|
||||
@@ -286,8 +295,7 @@ class TensorMapManager:
|
||||
t_tensormap = e.tensormap(args)
|
||||
TENSORMAP_SIZE_IN_BYTES = 128
|
||||
t_tensormap_device = driver.utils.cuMemAlloc(TENSORMAP_SIZE_IN_BYTES)
|
||||
driver.utils.cuMemcpyHtoD(
|
||||
t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES)
|
||||
driver.utils.cuMemcpyHtoD(t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES)
|
||||
self.tensormaps_device[key] = t_tensormap_device
|
||||
return int(self.tensormaps_device[key])
|
||||
|
||||
|
||||
@@ -109,7 +109,6 @@ from .random import (
|
||||
uint32_to_uniform_float,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TRITON_MAX_TENSOR_NUMEL",
|
||||
"abs",
|
||||
|
||||
@@ -22,10 +22,8 @@ def builtin(fn: T) -> T:
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if "_builder" not in kwargs or kwargs["_builder"] is None:
|
||||
raise ValueError(
|
||||
"Did you forget to add @triton.jit ? "
|
||||
"(`_builder` argument must be provided outside of JIT functions.)"
|
||||
)
|
||||
raise ValueError("Did you forget to add @triton.jit ? "
|
||||
"(`_builder` argument must be provided outside of JIT functions.)")
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
setattr(wrapper, TRITON_BUILTIN, True)
|
||||
@@ -54,7 +52,7 @@ def _to_tensor(x, builder):
|
||||
else:
|
||||
raise RuntimeError(f'Nonrepresentable integer {x}.')
|
||||
elif isinstance(x, float):
|
||||
min_float32 = 2 ** -126
|
||||
min_float32 = 2**-126
|
||||
max_float32 = (2 - 2**-23) * 2**127
|
||||
abs_x = __builtins__['abs'](x)
|
||||
if abs_x == float("inf") or\
|
||||
@@ -229,7 +227,7 @@ class dtype:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.name,))
|
||||
return hash((self.name, ))
|
||||
|
||||
@property
|
||||
def scalar(self):
|
||||
@@ -279,6 +277,7 @@ class dtype:
|
||||
|
||||
|
||||
class pointer_type(dtype):
|
||||
|
||||
def __init__(self, element_ty: dtype, address_space: int = 1):
|
||||
if not isinstance(element_ty, dtype):
|
||||
raise TypeError('element_ty is a {type(element_ty).__name__}.')
|
||||
@@ -313,6 +312,7 @@ class pointer_type(dtype):
|
||||
|
||||
|
||||
class block_type(dtype):
|
||||
|
||||
def __init__(self, element_ty: dtype, shape: List):
|
||||
self.element_ty = element_ty
|
||||
|
||||
@@ -363,6 +363,7 @@ class block_type(dtype):
|
||||
|
||||
|
||||
class function_type(dtype):
|
||||
|
||||
def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None:
|
||||
self.ret_types = ret_types
|
||||
self.param_types = param_types
|
||||
@@ -511,7 +512,7 @@ class constexpr:
|
||||
return constexpr(~self.value)
|
||||
|
||||
def __pow__(self, other):
|
||||
return constexpr(self.value ** other.value)
|
||||
return constexpr(self.value**other.value)
|
||||
|
||||
def __rshift__(self, other):
|
||||
return constexpr(self.value >> other.value)
|
||||
@@ -527,6 +528,7 @@ class constexpr:
|
||||
|
||||
|
||||
class tensor:
|
||||
|
||||
def __init__(self, handle, type: dtype):
|
||||
# IR handle
|
||||
self.handle = handle
|
||||
@@ -993,6 +995,7 @@ def expand_dims(input, axis, _builder=None):
|
||||
ret = semantic.expand_dims(ret, a, _builder)
|
||||
return ret
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Linear Algebra
|
||||
# -----------------------
|
||||
@@ -1141,6 +1144,7 @@ def advance(base: tensor, offsets, _builder=None):
|
||||
"""
|
||||
return semantic.advance(base, offsets, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Atomic Memory Operations
|
||||
# -----------------------
|
||||
@@ -1253,6 +1257,7 @@ def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
# Conditioning
|
||||
# -----------------------
|
||||
|
||||
|
||||
@builtin
|
||||
def where(condition, x, y, _builder=None):
|
||||
"""
|
||||
@@ -1280,6 +1285,7 @@ def where(condition, x, y, _builder=None):
|
||||
# Math
|
||||
# -----------------------
|
||||
|
||||
|
||||
@builtin
|
||||
def umulhi(x, y, _builder=None):
|
||||
"""
|
||||
@@ -1373,6 +1379,7 @@ def abs(x, _builder=None):
|
||||
# Reductions
|
||||
# -----------------------
|
||||
|
||||
|
||||
def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]:
|
||||
|
||||
def _decorator(func: T) -> T:
|
||||
@@ -1411,8 +1418,7 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
|
||||
"""
|
||||
if isinstance(input, tensor):
|
||||
return reduce((input,), axis, combine_fn,
|
||||
_builder=_builder, _generator=_generator)[0]
|
||||
return reduce((input, ), axis, combine_fn, _builder=_builder, _generator=_generator)[0]
|
||||
|
||||
def make_combine_region(reduce_op):
|
||||
in_scalar_tys = [t.type.scalar for t in input]
|
||||
@@ -1422,14 +1428,14 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
with _insertion_guard(_builder):
|
||||
param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
|
||||
block = _builder.create_block_with_parent(region, param_types)
|
||||
args = [tensor(block.arg(i), ty)
|
||||
for i, ty in enumerate(prototype.param_types)]
|
||||
args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
|
||||
results = _generator.call_JitFunction(combine_fn, args, kwargs={})
|
||||
if isinstance(results, tensor):
|
||||
handles = [results.handle]
|
||||
else:
|
||||
handles = [r.handle for r in results]
|
||||
_builder.create_reduce_ret(*handles)
|
||||
|
||||
if axis is not None:
|
||||
axis = _constexpr_to_value(axis)
|
||||
return semantic.reduction(input, axis, make_combine_region, _builder)
|
||||
@@ -1459,8 +1465,7 @@ def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None
|
||||
index = expand_dims(index, axes_to_expand, _builder=_builder)
|
||||
index = broadcast_to(index, input.shape, _builder=_builder)
|
||||
|
||||
rvalue, rindices = reduce((input, index), axis, combine_fn,
|
||||
_builder=_builder, _generator=_generator)
|
||||
rvalue, rindices = reduce((input, index), axis, combine_fn, _builder=_builder, _generator=_generator)
|
||||
return rvalue, rindices
|
||||
|
||||
|
||||
@@ -1468,6 +1473,7 @@ def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None
|
||||
# Scans
|
||||
# -----------------------
|
||||
|
||||
|
||||
def _add_scan_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]:
|
||||
|
||||
def _decorator(func: T) -> T:
|
||||
@@ -1492,8 +1498,7 @@ def associative_scan(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
|
||||
"""
|
||||
if isinstance(input, tensor):
|
||||
return associative_scan((input,), axis, combine_fn,
|
||||
_builder=_builder, _generator=_generator)[0]
|
||||
return associative_scan((input, ), axis, combine_fn, _builder=_builder, _generator=_generator)[0]
|
||||
|
||||
def make_combine_region(scan_op):
|
||||
in_scalar_tys = [t.type.scalar for t in input]
|
||||
@@ -1503,17 +1508,18 @@ def associative_scan(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
with _insertion_guard(_builder):
|
||||
param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
|
||||
block = _builder.create_block_with_parent(region, param_types)
|
||||
args = [tensor(block.arg(i), ty)
|
||||
for i, ty in enumerate(prototype.param_types)]
|
||||
args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
|
||||
results = _generator.call_JitFunction(combine_fn, args, kwargs={})
|
||||
if isinstance(results, tensor):
|
||||
handles = [results.handle]
|
||||
else:
|
||||
handles = [r.handle for r in results]
|
||||
_builder.create_scan_ret(*handles)
|
||||
|
||||
axis = _constexpr_to_value(axis)
|
||||
return semantic.associative_scan(input, axis, make_combine_region, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Compiler Hint Ops
|
||||
# -----------------------
|
||||
@@ -1576,6 +1582,8 @@ def max_constancy(input, values, _builder=None):
|
||||
raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
|
||||
values = [x.value for x in values]
|
||||
return semantic.max_constancy(input, values)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Debugging functions
|
||||
# -----------------------
|
||||
@@ -1715,12 +1723,12 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur
|
||||
broadcast_arg = dispatch_args[0]
|
||||
# Get the broadcast shape over all the arguments
|
||||
for i, item in enumerate(dispatch_args):
|
||||
_, broadcast_arg = semantic.binary_op_type_checking_impl(
|
||||
item, broadcast_arg, _builder, arithmetic_check=False)
|
||||
_, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder,
|
||||
arithmetic_check=False)
|
||||
# Change the shape of each argument based on the broadcast shape
|
||||
for i in range(len(dispatch_args)):
|
||||
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
|
||||
dispatch_args[i], broadcast_arg, _builder, arithmetic_check=False)
|
||||
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder,
|
||||
arithmetic_check=False)
|
||||
ret_shape = broadcast_arg.shape
|
||||
res_ty = block_type(dtype, ret_shape)
|
||||
call = _builder.create_inline_asm(asm, constraints, [t.handle for t in args], res_ty.to_ir(_builder), is_pure, pack)
|
||||
@@ -1733,7 +1741,6 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur
|
||||
|
||||
|
||||
class static_range:
|
||||
|
||||
"""
|
||||
Iterator that counts upward forever.
|
||||
|
||||
@@ -1777,7 +1784,9 @@ class static_range:
|
||||
# Extern functions
|
||||
# -----------------------
|
||||
|
||||
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, is_pure: bool, _builder=None):
|
||||
|
||||
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple,
|
||||
is_pure: bool, _builder=None):
|
||||
'''
|
||||
Dispatch a function to a library
|
||||
:param func: the function to dispatch
|
||||
@@ -1819,7 +1828,8 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic
|
||||
return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type)
|
||||
|
||||
|
||||
def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, _builder=None):
|
||||
def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool,
|
||||
_builder=None):
|
||||
'''
|
||||
Dispatch an elementwise function to a library
|
||||
:param lib_name: the name of the library
|
||||
@@ -1848,12 +1858,12 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
|
||||
broadcast_arg = dispatch_args[0]
|
||||
# Get the broadcast shape over all the arguments
|
||||
for i, item in enumerate(dispatch_args):
|
||||
_, broadcast_arg = semantic.binary_op_type_checking_impl(
|
||||
item, broadcast_arg, _builder, arithmetic_check=arithmetic_check)
|
||||
_, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder,
|
||||
arithmetic_check=arithmetic_check)
|
||||
# Change the shape of each argument based on the broadcast shape
|
||||
for i in range(len(dispatch_args)):
|
||||
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
|
||||
dispatch_args[i], broadcast_arg, _builder, arithmetic_check=arithmetic_check)
|
||||
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder,
|
||||
arithmetic_check=arithmetic_check)
|
||||
if not all_scalar:
|
||||
ret_shape = broadcast_arg.shape
|
||||
func = getattr(_builder, "create_extern_elementwise")
|
||||
|
||||
@@ -3,16 +3,14 @@ from .. import core
|
||||
|
||||
@core.extern
|
||||
def globaltimer(_builder=None):
|
||||
return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [],
|
||||
dtype=core.int64, is_pure=False,
|
||||
pack=1, _builder=_builder)
|
||||
return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], dtype=core.int64, is_pure=False, pack=1,
|
||||
_builder=_builder)
|
||||
|
||||
|
||||
@core.extern
|
||||
def smid(_builder=None):
|
||||
return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [],
|
||||
dtype=core.int32, is_pure=True,
|
||||
pack=1, _builder=_builder)
|
||||
return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1,
|
||||
_builder=_builder)
|
||||
|
||||
|
||||
@core.builtin
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -91,6 +91,7 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
# two_to_the_minus_32: tl.constexpr = 2.328306e-10
|
||||
# return x * two_to_the_minus_32
|
||||
|
||||
|
||||
@jit
|
||||
def uint32_to_uniform_float(x):
|
||||
"""
|
||||
@@ -134,6 +135,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
u4 = uint32_to_uniform_float(i4)
|
||||
return u1, u2, u3, u4
|
||||
|
||||
|
||||
# -------------------
|
||||
# randn
|
||||
# -------------------
|
||||
|
||||
@@ -16,10 +16,12 @@ def _is_cuda(target):
|
||||
from ..compiler.compiler import CudaTargetDescriptor
|
||||
return isinstance(target, CudaTargetDescriptor)
|
||||
|
||||
|
||||
# Create custom exception that prints message "hello"
|
||||
|
||||
|
||||
class IncompatibleTypeErrorImpl(Exception):
|
||||
|
||||
def __init__(self, type_a, type_b):
|
||||
self.type_a = type_a
|
||||
self.type_b = type_b
|
||||
@@ -31,6 +33,7 @@ class IncompatibleTypeErrorImpl(Exception):
|
||||
# Programming Model
|
||||
# ===----------------------------------------------------------------------===##
|
||||
|
||||
|
||||
def program_id(axis: int, builder: ir.builder) -> tl.tensor:
|
||||
if axis not in (0, 1, 2):
|
||||
raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}")
|
||||
@@ -42,6 +45,7 @@ def num_programs(axis: int, builder: ir.builder) -> tl.tensor:
|
||||
raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}")
|
||||
return tl.tensor(builder.create_get_num_programs(axis), tl.int32)
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
# Implicit Casting Utilities
|
||||
# ===----------------------------------------------------------------------===//
|
||||
@@ -92,10 +96,12 @@ def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> t
|
||||
# 5 ) both operands are integer and undergo
|
||||
# integer promotion
|
||||
if div_or_mod and a_ty.int_signedness != b_ty.int_signedness:
|
||||
raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + " because they have different signedness;"
|
||||
raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() +
|
||||
" because they have different signedness;"
|
||||
"this is unlikely to result in a useful answer. Cast them to the same signedness.")
|
||||
return integer_promote_impl(a_ty, b_ty)
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
# Binary Operators
|
||||
# ===----------------------------------------------------------------------===//
|
||||
@@ -113,12 +119,9 @@ def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -
|
||||
raise IncompatibleTypeErrorImpl(type_a, type_b)
|
||||
|
||||
|
||||
def binary_op_type_checking_impl(lhs: tl.tensor,
|
||||
rhs: tl.tensor,
|
||||
builder: ir.builder,
|
||||
allow_lhs_ptr=False, allow_rhs_ptr=False,
|
||||
arithmetic_check=True, div_or_mod=False
|
||||
) -> Tuple[tl.tensor, tl.tensor]:
|
||||
def binary_op_type_checking_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder, allow_lhs_ptr=False,
|
||||
allow_rhs_ptr=False, arithmetic_check=True,
|
||||
div_or_mod=False) -> Tuple[tl.tensor, tl.tensor]:
|
||||
# implicit broadcasting
|
||||
lhs, rhs = broadcast_impl_value(lhs, rhs, builder)
|
||||
# implicit typecasting
|
||||
@@ -133,9 +136,7 @@ def binary_op_type_checking_impl(lhs: tl.tensor,
|
||||
return lhs, rhs
|
||||
|
||||
|
||||
def add(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def add(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder, True, True)
|
||||
input_scalar_ty = input.type.scalar
|
||||
other_scalar_ty = other.type.scalar
|
||||
@@ -159,15 +160,12 @@ def add(input: tl.tensor,
|
||||
assert False
|
||||
|
||||
|
||||
def sub(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def sub(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder, True, False)
|
||||
scalar_ty = input.type.scalar
|
||||
# ptr - offset
|
||||
if scalar_ty.is_ptr():
|
||||
return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle),
|
||||
input.type)
|
||||
return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle), input.type)
|
||||
# float - float
|
||||
if scalar_ty.is_floating():
|
||||
return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type)
|
||||
@@ -177,9 +175,7 @@ def sub(input: tl.tensor,
|
||||
assert False
|
||||
|
||||
|
||||
def mul(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def mul(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder)
|
||||
scalar_ty = input.type.scalar
|
||||
# float * float
|
||||
@@ -191,9 +187,7 @@ def mul(input: tl.tensor,
|
||||
assert False
|
||||
|
||||
|
||||
def truediv(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def truediv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
|
||||
input_scalar_ty = input.type.scalar
|
||||
other_scalar_ty = other.type.scalar
|
||||
@@ -219,9 +213,7 @@ def truediv(input: tl.tensor,
|
||||
return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type)
|
||||
|
||||
|
||||
def floordiv(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def floordiv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
|
||||
input_scalar_ty = input.type.scalar
|
||||
other_scalar_ty = other.type.scalar
|
||||
@@ -236,10 +228,7 @@ def floordiv(input: tl.tensor,
|
||||
assert False
|
||||
|
||||
|
||||
def fdiv(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
ieee_rounding: bool,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def fdiv(input: tl.tensor, other: tl.tensor, ieee_rounding: bool, builder: ir.builder) -> tl.tensor:
|
||||
input_scalar_ty = input.type.scalar
|
||||
other_scalar_ty = other.type.scalar
|
||||
if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating():
|
||||
@@ -249,18 +238,14 @@ def fdiv(input: tl.tensor,
|
||||
return tl.tensor(ret, input.type)
|
||||
|
||||
|
||||
def mod(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def mod(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
|
||||
scalar_ty = input.type.scalar
|
||||
other_scalar_ty = other.type.scalar
|
||||
# float % float
|
||||
if scalar_ty.is_floating():
|
||||
# input - input.div(other, rounding_mode="floor") * other
|
||||
ret = sub(input, mul(floor(fdiv(input, other, False, builder), builder),
|
||||
other, builder),
|
||||
builder)
|
||||
ret = sub(input, mul(floor(fdiv(input, other, False, builder), builder), other, builder), builder)
|
||||
return ret
|
||||
# % int
|
||||
elif scalar_ty.is_int():
|
||||
@@ -274,13 +259,13 @@ def mod(input: tl.tensor,
|
||||
return tl.tensor(builder.create_urem(input.handle, other.handle), input.type)
|
||||
assert False
|
||||
|
||||
|
||||
##############
|
||||
# bitwise ops
|
||||
##############
|
||||
|
||||
|
||||
def bitwise_op_type_checking_impl(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
def bitwise_op_type_checking_impl(input: tl.tensor, other: tl.tensor,
|
||||
builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder, False, False, False)
|
||||
input_sca_ty = input.type.scalar
|
||||
@@ -295,23 +280,17 @@ def bitwise_op_type_checking_impl(input: tl.tensor,
|
||||
return input, other
|
||||
|
||||
|
||||
def and_(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def and_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = bitwise_op_type_checking_impl(input, other, builder)
|
||||
return tl.tensor(builder.create_and(input.handle, other.handle), input.type)
|
||||
|
||||
|
||||
def or_(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def or_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = bitwise_op_type_checking_impl(input, other, builder)
|
||||
return tl.tensor(builder.create_or(input.handle, other.handle), input.type)
|
||||
|
||||
|
||||
def xor_(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def xor_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = bitwise_op_type_checking_impl(input, other, builder)
|
||||
return tl.tensor(builder.create_xor(input.handle, other.handle), input.type)
|
||||
|
||||
@@ -338,26 +317,21 @@ def not_(input: tl.tensor, builder: ir.builder):
|
||||
return invert(input, builder)
|
||||
|
||||
|
||||
def lshr(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def lshr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = bitwise_op_type_checking_impl(input, other, builder)
|
||||
return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type)
|
||||
|
||||
|
||||
def ashr(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def ashr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = bitwise_op_type_checking_impl(input, other, builder)
|
||||
return tl.tensor(builder.create_ashr(input.handle, other.handle), input.type)
|
||||
|
||||
|
||||
def shl(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def shl(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = bitwise_op_type_checking_impl(input, other, builder)
|
||||
return tl.tensor(builder.create_shl(input.handle, other.handle), input.type)
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
# Unary Operators
|
||||
# ===----------------------------------------------------------------------===//
|
||||
@@ -367,8 +341,7 @@ def plus(input: tl.tensor) -> tl.tensor:
|
||||
return input
|
||||
|
||||
|
||||
def minus(input: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input_sca_ty = input.type.scalar
|
||||
if input_sca_ty.is_ptr():
|
||||
raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")")
|
||||
@@ -376,8 +349,7 @@ def minus(input: tl.tensor,
|
||||
return sub(_0, input, builder)
|
||||
|
||||
|
||||
def invert(input: tl.tensor,
|
||||
builder: tl.tensor) -> tl.tensor:
|
||||
def invert(input: tl.tensor, builder: tl.tensor) -> tl.tensor:
|
||||
input_sca_ty = input.type.scalar
|
||||
if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
|
||||
raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")")
|
||||
@@ -395,9 +367,7 @@ def _bool_like(v: tl.tensor) -> tl.block_type:
|
||||
return tl.block_type(tl.int1, shape)
|
||||
|
||||
|
||||
def greater_than(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def greater_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder)
|
||||
scalar_ty = input.type.scalar
|
||||
# float > float
|
||||
@@ -412,9 +382,7 @@ def greater_than(input: tl.tensor,
|
||||
assert False
|
||||
|
||||
|
||||
def greater_equal(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def greater_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder)
|
||||
scalar_ty = input.type.scalar
|
||||
# float >= float
|
||||
@@ -429,9 +397,7 @@ def greater_equal(input: tl.tensor,
|
||||
assert False
|
||||
|
||||
|
||||
def less_than(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def less_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder)
|
||||
scalar_ty = input.type.scalar
|
||||
# float < float
|
||||
@@ -446,9 +412,7 @@ def less_than(input: tl.tensor,
|
||||
assert False
|
||||
|
||||
|
||||
def less_equal(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def less_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder)
|
||||
scalar_ty = input.type.scalar
|
||||
# float < float
|
||||
@@ -463,9 +427,7 @@ def less_equal(input: tl.tensor,
|
||||
assert False
|
||||
|
||||
|
||||
def equal(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder)
|
||||
scalar_ty = input.type.scalar
|
||||
# float == float
|
||||
@@ -477,9 +439,7 @@ def equal(input: tl.tensor,
|
||||
assert False
|
||||
|
||||
|
||||
def not_equal(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def not_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
input, other = binary_op_type_checking_impl(input, other, builder)
|
||||
scalar_ty = input.type.scalar
|
||||
# float == float
|
||||
@@ -490,6 +450,7 @@ def not_equal(input: tl.tensor,
|
||||
return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input))
|
||||
assert False
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
# Block Creation
|
||||
# ===----------------------------------------------------------------------===//
|
||||
@@ -532,6 +493,7 @@ def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.te
|
||||
# Shape Manipulation
|
||||
# ===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor:
|
||||
assert not value.type.is_block(), "Cannot splat a block tensor"
|
||||
if len(shape) == 0:
|
||||
@@ -540,9 +502,7 @@ def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor:
|
||||
return tl.tensor(builder.create_splat(value.handle, shape), ret_ty)
|
||||
|
||||
|
||||
def view(input: tl.tensor,
|
||||
dst_shape: List[int],
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def view(input: tl.tensor, dst_shape: List[int], builder: ir.builder) -> tl.tensor:
|
||||
# TODO: disable when TritonToTritonGPU handles views properly
|
||||
|
||||
# assert len(input.shape) == len(dst_shape)
|
||||
@@ -555,9 +515,7 @@ def view(input: tl.tensor,
|
||||
return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty)
|
||||
|
||||
|
||||
def reshape(input: tl.tensor,
|
||||
dst_shape: List[int],
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def reshape(input: tl.tensor, dst_shape: List[int], builder: ir.builder) -> tl.tensor:
|
||||
raise ValueError("`reshape` is not supported yet. Please use `view` instead if applicable. "
|
||||
"Note that view may reorder elements in an implementation- and context- dependent way.")
|
||||
|
||||
@@ -587,9 +545,7 @@ def trans(input: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
return tl.tensor(builder.create_trans(input.handle), ret_type)
|
||||
|
||||
|
||||
def broadcast_impl_shape(input: tl.tensor,
|
||||
shape: List[int],
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def broadcast_impl_shape(input: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor:
|
||||
if not input.type.is_block():
|
||||
ret_ty = tl.block_type(input.type, shape)
|
||||
return tl.tensor(builder.create_splat(input.handle, shape), ret_ty)
|
||||
@@ -607,9 +563,7 @@ def broadcast_impl_shape(input: tl.tensor,
|
||||
return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty)
|
||||
|
||||
|
||||
def broadcast_impl_value(lhs: tl.tensor,
|
||||
rhs: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
lhs_ty = lhs.type
|
||||
rhs_ty = rhs.type
|
||||
|
||||
@@ -629,13 +583,15 @@ def broadcast_impl_value(lhs: tl.tensor,
|
||||
if len(lhs_shape) < len(rhs_shape):
|
||||
# Add new axes to lhs
|
||||
for dim in range(len(lhs_shape), len(rhs_shape)):
|
||||
lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0), tl.block_type(lhs_ty.scalar, [1] + lhs_shape))
|
||||
lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0),
|
||||
tl.block_type(lhs_ty.scalar, [1] + lhs_shape))
|
||||
lhs_ty = lhs.type
|
||||
lhs_shape = lhs_ty.get_block_shapes()
|
||||
elif len(rhs_shape) < len(lhs_shape):
|
||||
# Add new axes to rhs
|
||||
for dim in range(len(rhs_shape), len(lhs_shape)):
|
||||
rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0), tl.block_type(rhs_ty.scalar, [1] + rhs_shape))
|
||||
rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0),
|
||||
tl.block_type(rhs_ty.scalar, [1] + rhs_shape))
|
||||
rhs_ty = rhs.type
|
||||
rhs_shape = rhs_ty.get_block_shapes()
|
||||
assert len(rhs_shape) == len(lhs_shape)
|
||||
@@ -661,14 +617,13 @@ def broadcast_impl_value(lhs: tl.tensor,
|
||||
# (scalar, scalar) => returns original blocks
|
||||
return lhs, rhs
|
||||
|
||||
|
||||
#######
|
||||
# cast
|
||||
#######
|
||||
|
||||
|
||||
def bitcast(input: tl.tensor,
|
||||
dst_ty: tl.dtype,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor:
|
||||
src_ty = input.type
|
||||
if src_ty.is_block():
|
||||
dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes())
|
||||
@@ -684,13 +639,10 @@ def bitcast(input: tl.tensor,
|
||||
if src_bits != dst_bits:
|
||||
raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to "
|
||||
"data-type of size " + str(dst_bits))
|
||||
return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
|
||||
|
||||
def cast(input: tl.tensor,
|
||||
dst_ty: tl.dtype,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor:
|
||||
src_ty = input.type
|
||||
if isinstance(dst_ty, tl.constexpr):
|
||||
dst_ty = dst_ty.value
|
||||
@@ -709,8 +661,7 @@ def cast(input: tl.tensor,
|
||||
# Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64
|
||||
if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \
|
||||
(src_sca_ty.is_floating() and dst_sca_ty.is_fp8()):
|
||||
return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
|
||||
# bf16 <=> (not fp32)
|
||||
if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \
|
||||
@@ -724,9 +675,7 @@ def cast(input: tl.tensor,
|
||||
dst_sca_ty.is_floating() and \
|
||||
src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth
|
||||
if truncate_fp:
|
||||
return tl.tensor(builder.create_fp_trunc(input.handle,
|
||||
dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
|
||||
# Standard floating types' casting: extension
|
||||
# fp32 => fp64
|
||||
@@ -736,9 +685,7 @@ def cast(input: tl.tensor,
|
||||
dst_sca_ty.is_floating() and \
|
||||
src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth
|
||||
if ext_fp:
|
||||
return tl.tensor(builder.create_fp_ext(input.handle,
|
||||
dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
|
||||
# Casting between integer types
|
||||
if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
|
||||
@@ -749,9 +696,7 @@ def cast(input: tl.tensor,
|
||||
_0 = tl.tensor(builder.get_null_value(ty), input.dtype)
|
||||
return not_equal(input, _0, builder)
|
||||
else:
|
||||
return tl.tensor(builder.create_int_cast(input.handle,
|
||||
dst_ty.to_ir(builder), sign_extend),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty)
|
||||
|
||||
# Casting standard floating types to integer types
|
||||
if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int():
|
||||
@@ -760,35 +705,24 @@ def cast(input: tl.tensor,
|
||||
_0 = tl.tensor(builder.get_null_value(ty), input.dtype)
|
||||
return not_equal(input, _0, builder)
|
||||
elif dst_sca_ty.is_int_signed():
|
||||
return tl.tensor(builder.create_fp_to_si(input.handle,
|
||||
dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
else:
|
||||
return tl.tensor(builder.create_fp_to_ui(input.handle,
|
||||
dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
|
||||
# Casting integer types to standard floating types
|
||||
if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating():
|
||||
if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed():
|
||||
return tl.tensor(builder.create_ui_to_fp(input.handle,
|
||||
dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
else:
|
||||
return tl.tensor(builder.create_si_to_fp(input.handle,
|
||||
dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
|
||||
# Casting pointer types to integer types
|
||||
if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
|
||||
bitwidth = dst_sca_ty.int_bitwidth
|
||||
if bitwidth == 64:
|
||||
return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
if bitwidth == 1:
|
||||
return not_equal(cast(input, tl.int64, builder),
|
||||
tl.tensor(builder.get_int64(0), tl.int64),
|
||||
builder)
|
||||
return not_equal(cast(input, tl.int64, builder), tl.tensor(builder.get_int64(0), tl.int64), builder)
|
||||
|
||||
# Casting integer types to pointer types
|
||||
if src_sca_ty.is_int() and dst_sca_ty.is_ptr():
|
||||
@@ -800,6 +734,7 @@ def cast(input: tl.tensor,
|
||||
|
||||
assert False, f'cannot cast {input} to {dst_ty}'
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
# Memory Operators
|
||||
# ===----------------------------------------------------------------------===//
|
||||
@@ -918,8 +853,8 @@ def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, evicti
|
||||
boundary_check = _canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes())
|
||||
|
||||
# Build IR
|
||||
return tl.tensor(builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction,
|
||||
is_volatile), dst_ty)
|
||||
return tl.tensor(
|
||||
builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty)
|
||||
|
||||
|
||||
def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder):
|
||||
@@ -975,19 +910,13 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_
|
||||
if not mask:
|
||||
return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty)
|
||||
else:
|
||||
return tl.tensor(builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache,
|
||||
eviction, is_volatile), dst_ty)
|
||||
return tl.tensor(
|
||||
builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction,
|
||||
is_volatile), dst_ty)
|
||||
|
||||
|
||||
def load(ptr: tl.tensor,
|
||||
mask: Optional[tl.tensor],
|
||||
other: Optional[tl.tensor],
|
||||
boundary_check,
|
||||
padding_option: str,
|
||||
cache_modifier: str,
|
||||
eviction_policy: str,
|
||||
is_volatile: bool,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check, padding_option: str,
|
||||
cache_modifier: str, eviction_policy: str, is_volatile: bool, builder: ir.builder) -> tl.tensor:
|
||||
# Cache, eviction and padding options
|
||||
cache = _str_to_load_cache_modifier(cache_modifier)
|
||||
eviction = _str_to_eviction_policy(eviction_policy)
|
||||
@@ -1012,7 +941,8 @@ def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builde
|
||||
if not val.type.is_block():
|
||||
val = broadcast_impl_shape(val, block_shape, builder)
|
||||
assert val.type.is_block(), "Value argument must be block type or a scalar"
|
||||
assert block_shape == val.type.get_block_shapes(), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch"
|
||||
assert block_shape == val.type.get_block_shapes(
|
||||
), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch"
|
||||
assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch"
|
||||
|
||||
elt_ty = ptr.type.element_ty.element_ty
|
||||
@@ -1070,13 +1000,8 @@ def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder):
|
||||
return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void)
|
||||
|
||||
|
||||
def store(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: Optional[tl.tensor],
|
||||
boundary_check,
|
||||
cache_modifier: str,
|
||||
eviction_policy: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], boundary_check, cache_modifier: str,
|
||||
eviction_policy: str, builder: ir.builder) -> tl.tensor:
|
||||
# Cache and eviction options
|
||||
cache = _str_to_store_cache_modifier(cache_modifier)
|
||||
eviction = _str_to_eviction_policy(eviction_policy)
|
||||
@@ -1094,12 +1019,7 @@ def store(ptr: tl.tensor,
|
||||
#########
|
||||
|
||||
|
||||
def atomic_cas(ptr: tl.tensor,
|
||||
cmp: tl.tensor,
|
||||
val: tl.tensor,
|
||||
sem: str,
|
||||
scope: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
|
||||
sem = _str_to_sem(sem)
|
||||
scope = _str_to_scope(scope)
|
||||
element_ty = ptr.type.scalar.element_ty
|
||||
@@ -1108,10 +1028,7 @@ def atomic_cas(ptr: tl.tensor,
|
||||
return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type)
|
||||
|
||||
|
||||
def atom_red_typechecking_impl(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
op: str,
|
||||
def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, op: str,
|
||||
builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]:
|
||||
if not ptr.type.scalar.is_ptr():
|
||||
raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
|
||||
@@ -1136,12 +1053,7 @@ def atom_red_typechecking_impl(ptr: tl.tensor,
|
||||
return ptr, val, mask
|
||||
|
||||
|
||||
def atomic_max(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
sem: str,
|
||||
scope: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder)
|
||||
sem = _str_to_sem(sem)
|
||||
scope = _str_to_scope(scope)
|
||||
@@ -1149,21 +1061,11 @@ def atomic_max(ptr: tl.tensor,
|
||||
# direct call to atomic_max for integers
|
||||
if sca_ty.is_int():
|
||||
if sca_ty.is_int_signed():
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX,
|
||||
ptr.handle,
|
||||
val.handle,
|
||||
mask.handle,
|
||||
sem,
|
||||
scope),
|
||||
val.type)
|
||||
return tl.tensor(
|
||||
builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
|
||||
else:
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX,
|
||||
ptr.handle,
|
||||
val.handle,
|
||||
mask.handle,
|
||||
sem,
|
||||
scope),
|
||||
val.type)
|
||||
return tl.tensor(
|
||||
builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
|
||||
# for float
|
||||
# return atomic_smax(i_ptr, i_val) if val >= 0
|
||||
# return atomic_umin(i_ptr, i_val) if val < 0
|
||||
@@ -1177,18 +1079,17 @@ def atomic_max(ptr: tl.tensor,
|
||||
i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder)
|
||||
pos = greater_equal(val, zero, builder)
|
||||
neg = less_than(val, zero, builder)
|
||||
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle, sem, scope), i_val.type)
|
||||
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle, sem, scope), i_val.type)
|
||||
pos_ret = tl.tensor(
|
||||
builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle,
|
||||
and_(mask, pos, builder).handle, sem, scope), i_val.type)
|
||||
neg_ret = tl.tensor(
|
||||
builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle,
|
||||
and_(mask, neg, builder).handle, sem, scope), i_val.type)
|
||||
ret = where(pos, pos_ret, neg_ret, builder)
|
||||
return bitcast(ret, sca_ty, builder)
|
||||
|
||||
|
||||
def atomic_min(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
sem: str,
|
||||
scope: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder)
|
||||
sem = _str_to_sem(sem)
|
||||
scope = _str_to_scope(scope)
|
||||
@@ -1196,21 +1097,11 @@ def atomic_min(ptr: tl.tensor,
|
||||
# direct call to atomic_min for integers
|
||||
if sca_ty.is_int():
|
||||
if sca_ty.is_int_signed():
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
|
||||
ptr.handle,
|
||||
val.handle,
|
||||
mask.handle,
|
||||
sem,
|
||||
scope),
|
||||
val.type)
|
||||
return tl.tensor(
|
||||
builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
|
||||
else:
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN,
|
||||
ptr.handle,
|
||||
val.handle,
|
||||
mask.handle,
|
||||
sem,
|
||||
scope),
|
||||
val.type)
|
||||
return tl.tensor(
|
||||
builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
|
||||
# for float
|
||||
# return atomic_smin(i_ptr, i_val) if val >= 0
|
||||
# return atomic_umax(i_ptr, i_val) if val < 0
|
||||
@@ -1224,30 +1115,17 @@ def atomic_min(ptr: tl.tensor,
|
||||
i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder)
|
||||
pos = greater_equal(val, zero, builder)
|
||||
neg = less_than(val, zero, builder)
|
||||
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
|
||||
i_ptr.handle,
|
||||
i_val.handle,
|
||||
and_(mask, pos, builder).handle,
|
||||
sem,
|
||||
scope),
|
||||
i_val.type)
|
||||
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX,
|
||||
i_ptr.handle,
|
||||
i_val.handle,
|
||||
and_(mask, neg, builder).handle,
|
||||
sem,
|
||||
scope),
|
||||
i_val.type)
|
||||
pos_ret = tl.tensor(
|
||||
builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle,
|
||||
and_(mask, pos, builder).handle, sem, scope), i_val.type)
|
||||
neg_ret = tl.tensor(
|
||||
builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, i_ptr.handle, i_val.handle,
|
||||
and_(mask, neg, builder).handle, sem, scope), i_val.type)
|
||||
ret = where(pos, pos_ret, neg_ret, builder)
|
||||
return bitcast(ret, sca_ty, builder)
|
||||
|
||||
|
||||
def atomic_add(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
sem: str,
|
||||
scope: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def atomic_add(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder)
|
||||
sem = _str_to_sem(sem)
|
||||
scope = _str_to_scope(scope)
|
||||
@@ -1256,52 +1134,38 @@ def atomic_add(ptr: tl.tensor,
|
||||
return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
|
||||
|
||||
|
||||
def atomic_and(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
sem: str,
|
||||
scope: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def atomic_and(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder)
|
||||
sem = _str_to_sem(sem)
|
||||
scope = _str_to_scope(scope)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope),
|
||||
val.type)
|
||||
|
||||
|
||||
def atomic_or(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
sem: str,
|
||||
scope: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def atomic_or(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder)
|
||||
sem = _str_to_sem(sem)
|
||||
scope = _str_to_scope(scope)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope),
|
||||
val.type)
|
||||
|
||||
|
||||
def atomic_xor(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
sem: str,
|
||||
scope: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def atomic_xor(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder)
|
||||
sem = _str_to_sem(sem)
|
||||
scope = _str_to_scope(scope)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope),
|
||||
val.type)
|
||||
|
||||
|
||||
def atomic_xchg(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
sem: str,
|
||||
scope: str,
|
||||
def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder)
|
||||
sem = _str_to_sem(sem)
|
||||
scope = _str_to_scope(scope)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope),
|
||||
val.type)
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
# Linear Algebra
|
||||
@@ -1321,13 +1185,9 @@ def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def dot(lhs: tl.tensor,
|
||||
rhs: tl.tensor,
|
||||
acc: tl.tensor,
|
||||
allow_tf32: bool,
|
||||
max_num_imprecise_acc: int,
|
||||
out_dtype: tl.dtype,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_num_imprecise_acc: int,
|
||||
out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
|
||||
|
||||
def assert_dtypes_valid(lhs_dtype, rhs_dtype, target):
|
||||
# Checks for non-cuda archs
|
||||
if not _is_cuda(target):
|
||||
@@ -1335,22 +1195,30 @@ def dot(lhs: tl.tensor,
|
||||
return
|
||||
# Checks for cuda arch
|
||||
if target.capability < 90:
|
||||
assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(), "Dot op does not support fp8e4nv on CUDA arch < 90"
|
||||
assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(
|
||||
), "Dot op does not support fp8e4nv on CUDA arch < 90"
|
||||
if lhs_dtype.is_fp8() and rhs_dtype.is_fp8():
|
||||
return
|
||||
assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
|
||||
else:
|
||||
assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(), "Dot op does not support fp8e4b15 on CUDA arch >= 90"
|
||||
assert not lhs_dtype.is_fp8e4b15x4() and not rhs_dtype.is_fp8e4b15x4(), "Dot op does not support fp8e4b15x4 on CUDA arch >= 90"
|
||||
assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(
|
||||
), "Dot op does not support fp8e4b15 on CUDA arch >= 90"
|
||||
assert not lhs_dtype.is_fp8e4b15x4() and not rhs_dtype.is_fp8e4b15x4(
|
||||
), "Dot op does not support fp8e4b15x4 on CUDA arch >= 90"
|
||||
if lhs_dtype.is_int() or rhs_dtype.is_int():
|
||||
assert lhs_dtype == rhs_dtype, f"Both operands must be same type. First operand ({lhs_dtype}) and second operand ({rhs_dtype})"
|
||||
assert lhs_dtype.is_int8() or lhs_dtype.is_uint8(), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})"
|
||||
assert lhs_dtype.is_int8() or lhs_dtype.is_uint8(
|
||||
), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})"
|
||||
elif lhs_dtype.is_fp8() or rhs_dtype.is_fp8():
|
||||
assert lhs_dtype.is_fp8e4nv() or lhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. First operand ({lhs_dtype})"
|
||||
assert rhs_dtype.is_fp8e4nv() or rhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. Second operand ({rhs_dtype})"
|
||||
assert lhs_dtype.is_fp8e4nv() or lhs_dtype.is_fp8e5(
|
||||
), f"Only supports fp8e4nv or fp8e5. First operand ({lhs_dtype})"
|
||||
assert rhs_dtype.is_fp8e4nv() or rhs_dtype.is_fp8e5(
|
||||
), f"Only supports fp8e4nv or fp8e5. Second operand ({rhs_dtype})"
|
||||
else:
|
||||
assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1(), f"Unsupported dtype {lhs_dtype}"
|
||||
assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1(), f"Unsupported dtype {rhs_dtype}"
|
||||
assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1(
|
||||
), f"Unsupported dtype {lhs_dtype}"
|
||||
assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1(
|
||||
), f"Unsupported dtype {rhs_dtype}"
|
||||
assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
|
||||
|
||||
assert lhs.type.is_block() and rhs.type.is_block()
|
||||
@@ -1359,7 +1227,8 @@ def dot(lhs: tl.tensor,
|
||||
|
||||
assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!"
|
||||
assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!"
|
||||
assert lhs.shape[1].value == rhs.shape[0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})"
|
||||
assert lhs.shape[1].value == rhs.shape[
|
||||
0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})"
|
||||
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
|
||||
and rhs.shape[1].value >= 16, \
|
||||
f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!"
|
||||
@@ -1370,7 +1239,8 @@ def dot(lhs: tl.tensor,
|
||||
_0 = builder.get_int32(0)
|
||||
ret_scalar_ty = tl.int32
|
||||
elif out_dtype.is_bf16():
|
||||
raise ValueError("out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`")
|
||||
raise ValueError(
|
||||
"out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`")
|
||||
elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16():
|
||||
_0 = builder.get_fp32(0)
|
||||
ret_scalar_ty = tl.float32
|
||||
@@ -1391,10 +1261,10 @@ def dot(lhs: tl.tensor,
|
||||
else:
|
||||
_0 = builder.create_splat(builder.get_fp32(0), [M, N])
|
||||
ret_ty = tl.block_type(ret_cast_scalar_ty, [M, N])
|
||||
ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
|
||||
ret_ty)
|
||||
ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), ret_ty)
|
||||
return cast(ret, ret_scalar_ty, builder)
|
||||
if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty) and ret_scalar_ty.primitive_bitwidth < 32:
|
||||
if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32,
|
||||
ret_scalar_ty) and ret_scalar_ty.primitive_bitwidth < 32:
|
||||
if lhs.type.scalar.is_int():
|
||||
ret_dot_scalar_ty = tl.int32
|
||||
_0 = builder.create_splat(builder.get_int32(0), [M, N])
|
||||
@@ -1402,8 +1272,7 @@ def dot(lhs: tl.tensor,
|
||||
ret_dot_scalar_ty = tl.float32
|
||||
_0 = builder.create_splat(builder.get_fp32(0), [M, N])
|
||||
ret_ty = tl.block_type(ret_dot_scalar_ty, [M, N])
|
||||
ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
|
||||
ret_ty)
|
||||
ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), ret_ty)
|
||||
return cast(ret, ret_scalar_ty, builder)
|
||||
ret_ty = tl.block_type(ret_scalar_ty, [M, N])
|
||||
if acc is None:
|
||||
@@ -1413,23 +1282,21 @@ def dot(lhs: tl.tensor,
|
||||
assert acc.type == ret_ty
|
||||
|
||||
# max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
|
||||
if not (_is_cuda(builder.target) and builder.target.capability == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and ret_scalar_ty.is_fp32()):
|
||||
if not (_is_cuda(builder.target) and builder.target.capability == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8()
|
||||
and ret_scalar_ty.is_fp32()):
|
||||
max_num_imprecise_acc = 0
|
||||
if max_num_imprecise_acc is None:
|
||||
max_num_imprecise_acc = 2**30
|
||||
|
||||
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc),
|
||||
ret_ty)
|
||||
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc), ret_ty)
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
# Indexing
|
||||
# ===----------------------------------------------------------------------===//
|
||||
|
||||
def where(condition: tl.tensor,
|
||||
x: tl.tensor,
|
||||
y: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
|
||||
def where(condition: tl.tensor, x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
condition = cast(condition, tl.int1, builder)
|
||||
if condition.type.is_block():
|
||||
condition, x = broadcast_impl_value(condition, x, builder)
|
||||
@@ -1442,14 +1309,13 @@ def where(condition: tl.tensor,
|
||||
ret_ty = x.type
|
||||
return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
# Reduction
|
||||
# ===----------------------------------------------------------------------===
|
||||
|
||||
|
||||
def reduction(
|
||||
inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder
|
||||
) -> Tuple[tl.tensor, ...]:
|
||||
def reduction(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder) -> Tuple[tl.tensor, ...]:
|
||||
if axis is None:
|
||||
new_inputs = []
|
||||
for i in range(len(inputs)):
|
||||
@@ -1475,10 +1341,7 @@ def reduction(
|
||||
region_builder_fn(reduce_op)
|
||||
reduce_op.verify()
|
||||
|
||||
return tuple(
|
||||
wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar)
|
||||
for i in range(len(inputs))
|
||||
)
|
||||
return tuple(wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar) for i in range(len(inputs)))
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===
|
||||
@@ -1486,9 +1349,8 @@ def reduction(
|
||||
# ===----------------------------------------------------------------------===
|
||||
|
||||
|
||||
def associative_scan(
|
||||
inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder
|
||||
) -> Tuple[tl.tensor, ...]:
|
||||
def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn,
|
||||
builder: ir.builder) -> Tuple[tl.tensor, ...]:
|
||||
if len(inputs) != 1:
|
||||
raise ValueError("Current implementation only support single tensor input")
|
||||
shape = inputs[0].type.shape
|
||||
@@ -1501,16 +1363,14 @@ def associative_scan(
|
||||
region_builder_fn(scan_op)
|
||||
scan_op.verify()
|
||||
|
||||
return tuple(
|
||||
wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar)
|
||||
for i in range(len(inputs))
|
||||
)
|
||||
return tuple(wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar) for i in range(len(inputs)))
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===
|
||||
# Math
|
||||
# ===----------------------------------------------------------------------===
|
||||
|
||||
|
||||
def _check_dtype(dtypes: List[str]) -> T:
|
||||
"""
|
||||
We're following libdevice's convention to check accepted data types for math functions.
|
||||
@@ -1519,7 +1379,9 @@ def _check_dtype(dtypes: List[str]) -> T:
|
||||
We should let the users know that they are using and invoke explicit cast to convert
|
||||
the data type to the supported one.
|
||||
"""
|
||||
|
||||
def wrapper(fn):
|
||||
|
||||
@wraps(fn)
|
||||
def check(*args, **kwargs):
|
||||
# concatenate args and kwargs
|
||||
@@ -1528,6 +1390,7 @@ def _check_dtype(dtypes: List[str]) -> T:
|
||||
if arg.type.scalar.name not in dtypes:
|
||||
raise ValueError(f"Expected dtype {dtypes} but got {arg.type.scalar.name}")
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return check
|
||||
|
||||
return wrapper
|
||||
@@ -1631,8 +1494,8 @@ def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.
|
||||
def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor:
|
||||
cond_ty = cond.type
|
||||
if not cond_ty.is_block():
|
||||
cond_ty = tl.block_type(cond_ty.scalar, (1,))
|
||||
cond = tl.tensor(builder.create_splat(cond.handle, (1,)), cond_ty)
|
||||
cond_ty = tl.block_type(cond_ty.scalar, (1, ))
|
||||
cond = tl.tensor(builder.create_splat(cond.handle, (1, )), cond_ty)
|
||||
return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void)
|
||||
|
||||
|
||||
|
||||
@@ -123,6 +123,7 @@ def maximum(x, y):
|
||||
"""
|
||||
return math.max(x, y)
|
||||
|
||||
|
||||
# max and argmax
|
||||
|
||||
|
||||
@@ -149,8 +150,7 @@ def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_reduction_docstr("maximum",
|
||||
return_indices_arg="return_indices",
|
||||
@core._add_reduction_docstr("maximum", return_indices_arg="return_indices",
|
||||
tie_break_arg="return_indices_tie_break_left")
|
||||
def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
|
||||
input = core._promote_reduction_input(input)
|
||||
@@ -175,6 +175,7 @@ def argmax(input, axis, tie_break_left=True):
|
||||
(_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
|
||||
return ret
|
||||
|
||||
|
||||
# min and argmin
|
||||
|
||||
|
||||
@@ -201,8 +202,7 @@ def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_reduction_docstr("minimum",
|
||||
return_indices_arg="return_indices",
|
||||
@core._add_reduction_docstr("minimum", return_indices_arg="return_indices",
|
||||
tie_break_arg="return_indices_tie_break_left")
|
||||
def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
|
||||
input = core._promote_reduction_input(input)
|
||||
@@ -222,8 +222,7 @@ def min(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_reduction_docstr("minimum index",
|
||||
tie_break_arg="tie_break_left")
|
||||
@core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left")
|
||||
def argmin(input, axis, tie_break_left=True):
|
||||
_, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
|
||||
return ret
|
||||
@@ -233,6 +232,7 @@ def argmin(input, axis, tie_break_left=True):
|
||||
def _sum_combine(a, b):
|
||||
return a + b
|
||||
|
||||
|
||||
# sum
|
||||
|
||||
|
||||
@@ -247,6 +247,7 @@ def sum(input, axis=None):
|
||||
def _xor_combine(a, b):
|
||||
return a ^ b
|
||||
|
||||
|
||||
# xor sum
|
||||
|
||||
|
||||
@@ -258,8 +259,8 @@ def xor_sum(input, axis=None, _builder=None, _generator=None):
|
||||
raise ValueError("xor_sum only supported for integers")
|
||||
|
||||
input = core._promote_reduction_input(input, _builder=_builder)
|
||||
return core.reduce(input, axis, _xor_combine,
|
||||
_builder=_builder, _generator=_generator)
|
||||
return core.reduce(input, axis, _xor_combine, _builder=_builder, _generator=_generator)
|
||||
|
||||
|
||||
# cumsum
|
||||
|
||||
@@ -271,6 +272,7 @@ def cumsum(input, axis=0):
|
||||
input = core._promote_reduction_input(input)
|
||||
return core.associative_scan(input, axis, _sum_combine)
|
||||
|
||||
|
||||
# cumprod
|
||||
|
||||
|
||||
|
||||
@@ -17,15 +17,14 @@ from ... import language as tl
|
||||
'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
|
||||
})
|
||||
@jit
|
||||
def _sdd_kernel(
|
||||
A, B, C,
|
||||
stride_za, stride_ha, stride_ma, stride_ak,
|
||||
stride_zb, stride_hb, stride_bk, stride_nb,
|
||||
stride_zc, stride_hc, stride_mc, stride_nc,
|
||||
K, grid_offset, lut,
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
|
||||
BLOCK: tl.constexpr, EVEN_K: tl.constexpr
|
||||
):
|
||||
def _sdd_kernel(A, B, C, #
|
||||
stride_za, stride_ha, stride_ma, stride_ak, #
|
||||
stride_zb, stride_hb, stride_bk, stride_nb, #
|
||||
stride_zc, stride_hc, stride_mc, stride_nc, #
|
||||
K, grid_offset, lut, #
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, #
|
||||
BLOCK: tl.constexpr, EVEN_K: tl.constexpr #
|
||||
):
|
||||
# ------------ #
|
||||
# - Prologue - #
|
||||
# ------------ #
|
||||
@@ -104,13 +103,13 @@ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=
|
||||
c = out
|
||||
grid = [c.shape[1], 1, c.shape[0]]
|
||||
_sdd_kernel[grid](
|
||||
a, b, c,
|
||||
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
|
||||
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
|
||||
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
|
||||
Ka, 0, lut,
|
||||
TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4,
|
||||
num_warps=4,
|
||||
a, b, c, #
|
||||
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), #
|
||||
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), #
|
||||
c.stride(0), c.stride(1), c.stride(2), c.stride(3), #
|
||||
Ka, 0, lut, #
|
||||
TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4, #
|
||||
num_warps=4 #
|
||||
)
|
||||
return c
|
||||
|
||||
@@ -120,6 +119,7 @@ def sdd_lut(layout, block, device):
|
||||
lut = lut.contiguous()
|
||||
return lut, None
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Dense = Sparse x Dense (DSD)
|
||||
# This operation uses a look-up table that contains pre-computed pointer increments
|
||||
@@ -128,15 +128,14 @@ def sdd_lut(layout, block, device):
|
||||
|
||||
|
||||
@jit
|
||||
def _dsd_kernel(
|
||||
A, B, C,
|
||||
stride_az, stride_ha, stride_am, stride_ak,
|
||||
stride_zb, stride_hb, stride_bk, stride_bn,
|
||||
stride_zc, stride_hc, stride_cm, stride_cn,
|
||||
DS0, DS1, lut,
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
|
||||
):
|
||||
def _dsd_kernel(A, B, C, #
|
||||
stride_az, stride_ha, stride_am, stride_ak, #
|
||||
stride_zb, stride_hb, stride_bk, stride_bn, #
|
||||
stride_zc, stride_hc, stride_cm, stride_cn, #
|
||||
DS0, DS1, lut, #
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, #
|
||||
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr #
|
||||
):
|
||||
# ------------ #
|
||||
# - Prologue - #
|
||||
# ------------ #
|
||||
@@ -229,13 +228,13 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N
|
||||
# compute output
|
||||
grid = lambda meta: [cdiv(BS3, meta['TILE_N']), width, BS0]
|
||||
_dsd_kernel[grid](
|
||||
a, b, c,
|
||||
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
|
||||
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
|
||||
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
|
||||
BS3, AS1, lut,
|
||||
TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
|
||||
num_warps=4, GROUP_SIZE_M=4,
|
||||
a, b, c, #
|
||||
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), #
|
||||
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), #
|
||||
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), #
|
||||
BS3, AS1, lut, #
|
||||
TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4, #
|
||||
num_warps=4, GROUP_SIZE_M=4 #
|
||||
)
|
||||
# exit()
|
||||
return c
|
||||
@@ -337,6 +336,7 @@ def dsd_lut(layout, block, step, trans, device):
|
||||
# create locks
|
||||
return lut, width
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Dense = Dense x Sparse (DDS)
|
||||
# -----------------------------
|
||||
@@ -346,6 +346,7 @@ def dsd_lut(layout, block, step, trans, device):
|
||||
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
|
||||
return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out)
|
||||
|
||||
|
||||
##############
|
||||
# MAIN API #
|
||||
##############
|
||||
@@ -356,10 +357,8 @@ class _matmul(torch.autograd.Function):
|
||||
fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul}
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
|
||||
c_lut, c_width, da_lut, da_width, db_lut, db_width, out
|
||||
):
|
||||
def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_width, da_lut, da_width, db_lut,
|
||||
db_width, out):
|
||||
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out)
|
||||
# save for backward
|
||||
ctx.save_for_backward(a, b)
|
||||
@@ -385,15 +384,13 @@ class _matmul(torch.autograd.Function):
|
||||
# gradients w.r.t. a
|
||||
if ctx.needs_input_grad[0]:
|
||||
mode_da = mode[1] + mode[0] + mode[2]
|
||||
da = _matmul.fn[mode_da](
|
||||
dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width,
|
||||
)
|
||||
da = _matmul.fn[mode_da](dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block,
|
||||
ctx.da_lut, ctx.da_width)
|
||||
# gradients w.r.t. b
|
||||
if ctx.needs_input_grad[1]:
|
||||
mode_db = mode[2] + mode[1] + mode[0]
|
||||
db = _matmul.fn[mode_db](
|
||||
a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width,
|
||||
)
|
||||
db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block,
|
||||
ctx.db_lut, ctx.db_width)
|
||||
dout = dc if ctx.has_out else None
|
||||
return da, db, None, None, None, \
|
||||
None, None, None, None, \
|
||||
@@ -427,11 +424,9 @@ class matmul:
|
||||
self.db_lut, self.db_width = sdd_lut(layout, block, device)
|
||||
|
||||
def __call__(self, a, b, out=None):
|
||||
c = _matmul.apply(
|
||||
a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
|
||||
self.c_lut, self.c_width,
|
||||
self.da_lut, self.da_width,
|
||||
self.db_lut, self.db_width,
|
||||
out
|
||||
)
|
||||
c = _matmul.apply(a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, #
|
||||
self.c_lut, self.c_width, #
|
||||
self.da_lut, self.da_width, #
|
||||
self.db_lut, self.db_width, #
|
||||
out)
|
||||
return c
|
||||
|
||||
@@ -18,14 +18,13 @@ def num_warps(n):
|
||||
|
||||
|
||||
@jit
|
||||
def _blocksparse_softmax_fwd(
|
||||
Out, A, stride_xz, LUT,
|
||||
R, extent, stride_zr, stride_hr, # relative attention
|
||||
scale, is_causal,
|
||||
ROW_SIZE: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
IS_DENSE: tl.constexpr,
|
||||
):
|
||||
def _blocksparse_softmax_fwd(Out, A, stride_xz, LUT, #
|
||||
R, extent, stride_zr, stride_hr, # relative attention
|
||||
scale, is_causal, #
|
||||
ROW_SIZE: tl.constexpr, #
|
||||
BLOCK_SIZE: tl.constexpr, #
|
||||
IS_DENSE: tl.constexpr #
|
||||
):
|
||||
h = tl.program_id(0)
|
||||
m = tl.program_id(1)
|
||||
z = tl.program_id(2)
|
||||
@@ -73,18 +72,16 @@ def _blocksparse_softmax_fwd(
|
||||
|
||||
|
||||
@jit
|
||||
def _blocksparse_softmax_bwd(
|
||||
DA, stride_zdx,
|
||||
DOut, stride_zdout,
|
||||
Out, stride_zout,
|
||||
scale,
|
||||
LUT,
|
||||
DR, extent, stride_zr, stride_hr, stride_er,
|
||||
is_causal,
|
||||
ROW_SIZE: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
IS_DENSE: tl.constexpr,
|
||||
):
|
||||
def _blocksparse_softmax_bwd(DA, stride_zdx, #
|
||||
DOut, stride_zdout, #
|
||||
Out, stride_zout, #
|
||||
scale, #
|
||||
LUT, #
|
||||
DR, extent, stride_zr, stride_hr, stride_er, #
|
||||
is_causal, #
|
||||
ROW_SIZE: tl.constexpr, #
|
||||
BLOCK_SIZE: tl.constexpr, #
|
||||
IS_DENSE: tl.constexpr):
|
||||
h = tl.program_id(0)
|
||||
m = tl.program_id(1)
|
||||
z = tl.program_id(2)
|
||||
@@ -133,6 +130,7 @@ def _blocksparse_softmax_bwd(
|
||||
|
||||
|
||||
class _softmax(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def make_lut(layout, block, device):
|
||||
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
|
||||
@@ -151,10 +149,7 @@ class _softmax(torch.autograd.Function):
|
||||
return lut, int(total_sizes.max())
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx, a, scale, rel_logits, is_causal,
|
||||
spdims, block, lut, maxlut, is_dense
|
||||
):
|
||||
def forward(ctx, a, scale, rel_logits, is_causal, spdims, block, lut, maxlut, is_dense):
|
||||
if scale is not None and isinstance(scale, torch.Tensor):
|
||||
assert scale.device.type == "cpu"
|
||||
scale = scale.item()
|
||||
@@ -165,14 +160,14 @@ class _softmax(torch.autograd.Function):
|
||||
# enqueue kernel
|
||||
out = torch.empty_like(a)
|
||||
_blocksparse_softmax_fwd[grid](
|
||||
out, a, a.stride(0), lut,
|
||||
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
|
||||
scale,
|
||||
is_causal,
|
||||
BLOCK_SIZE=block,
|
||||
ROW_SIZE=next_power_of_2(maxlut),
|
||||
IS_DENSE=is_dense,
|
||||
num_warps=num_warps(maxlut)
|
||||
out, a, a.stride(0), lut, #
|
||||
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn#
|
||||
scale, #
|
||||
is_causal, #
|
||||
BLOCK_SIZE=block, #
|
||||
ROW_SIZE=next_power_of_2(maxlut), #
|
||||
IS_DENSE=is_dense, #
|
||||
num_warps=num_warps(maxlut) #
|
||||
)
|
||||
# save to context
|
||||
# ctx.mark_dirty(x)
|
||||
@@ -201,28 +196,23 @@ class _softmax(torch.autograd.Function):
|
||||
grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)
|
||||
da = torch.empty_like(dout)
|
||||
_blocksparse_softmax_bwd[grid](
|
||||
da, da.stride(0),
|
||||
dout, dout.stride(0),
|
||||
out, out.stride(0),
|
||||
ctx.scale,
|
||||
lut,
|
||||
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],
|
||||
ctx.is_causal,
|
||||
BLOCK_SIZE=ctx.block,
|
||||
ROW_SIZE=next_power_of_2(ctx.maxlut),
|
||||
IS_DENSE=ctx.is_dense,
|
||||
num_warps=num_warps(ctx.maxlut)
|
||||
da, da.stride(0), #
|
||||
dout, dout.stride(0), #
|
||||
out, out.stride(0), #
|
||||
ctx.scale, #
|
||||
lut, #
|
||||
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2], #
|
||||
ctx.is_causal, #
|
||||
BLOCK_SIZE=ctx.block, #
|
||||
ROW_SIZE=next_power_of_2(ctx.maxlut), #
|
||||
IS_DENSE=ctx.is_dense, #
|
||||
num_warps=num_warps(ctx.maxlut) #
|
||||
)
|
||||
return (da, None, None, dr, None,
|
||||
None, None, None, None, None,
|
||||
None,
|
||||
None, None, None,
|
||||
None,
|
||||
None, None, None
|
||||
)
|
||||
return (da, None, None, dr, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
|
||||
|
||||
|
||||
class softmax:
|
||||
|
||||
def __init__(self, layout, block, device, is_dense=False):
|
||||
self.spdims = layout.shape
|
||||
self.layout = layout
|
||||
@@ -233,8 +223,6 @@ class softmax:
|
||||
def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False):
|
||||
if rel_logits is not None and rel_logits.dtype != a.dtype:
|
||||
raise ValueError(f"relative position embedding must be {a.dtype}")
|
||||
a = _softmax.apply(
|
||||
a, scale, rel_logits, is_causal,
|
||||
self.spdims, self.block, self.lut, self.maxlut, self.is_dense,
|
||||
)
|
||||
a = _softmax.apply(a, scale, rel_logits, is_causal, self.spdims, self.block, self.lut, self.maxlut,
|
||||
self.is_dense)
|
||||
return a
|
||||
|
||||
@@ -59,6 +59,7 @@ def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
|
||||
|
||||
|
||||
class _cross_entropy(torch.autograd.Function):
|
||||
|
||||
@classmethod
|
||||
def forward(cls, ctx, logits, indices):
|
||||
# make sure we can use triton
|
||||
|
||||
@@ -15,22 +15,19 @@ from .. import language as tl
|
||||
|
||||
|
||||
@jit
|
||||
def _fwd_kernel(
|
||||
# fmt: off
|
||||
Q, K, V, sm_scale,
|
||||
L,
|
||||
Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
stride_oz, stride_oh, stride_om, stride_on,
|
||||
Z, H, N_CTX,
|
||||
Z_H_N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
# fmt: on
|
||||
):
|
||||
def _fwd_kernel(Q, K, V, sm_scale, #
|
||||
L, #
|
||||
Out, #
|
||||
stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vn, stride_vk, #
|
||||
stride_oz, stride_oh, stride_om, stride_on, #
|
||||
Z, H, N_CTX, #
|
||||
Z_H_N_CTX, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
|
||||
BLOCK_N: tl.constexpr, #
|
||||
IS_CAUSAL: tl.constexpr #
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
qvk_offset = off_hz * stride_qh
|
||||
@@ -132,27 +129,24 @@ def _bwd_preprocess(
|
||||
|
||||
|
||||
@jit
|
||||
def _bwd_kernel_one_col_block(
|
||||
# fmt: off
|
||||
Q, K, V, sm_scale, qk_scale,
|
||||
Out, DO,
|
||||
DQ, DK, DV,
|
||||
L,
|
||||
D,
|
||||
Q_block_ptr, K_block_ptr, V_block_ptr,
|
||||
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr,
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
Z, H, N_CTX,
|
||||
off_h, off_z, off_hz, start_n, num_block,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
SEQUENCE_PARALLEL: tl.constexpr,
|
||||
CAUSAL: tl.constexpr,
|
||||
MMA_V3: tl.constexpr
|
||||
# fmt: on
|
||||
):
|
||||
def _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, #
|
||||
Out, DO, #
|
||||
DQ, DK, DV, #
|
||||
L, #
|
||||
D, #
|
||||
Q_block_ptr, K_block_ptr, V_block_ptr, #
|
||||
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vn, stride_vk, #
|
||||
Z, H, N_CTX, #
|
||||
off_h, off_z, off_hz, start_n, num_block, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
|
||||
BLOCK_N: tl.constexpr, #
|
||||
SEQUENCE_PARALLEL: tl.constexpr, #
|
||||
CAUSAL: tl.constexpr, #
|
||||
MMA_V3: tl.constexpr #
|
||||
):
|
||||
if CAUSAL:
|
||||
lo = start_n * BLOCK_M
|
||||
else:
|
||||
@@ -235,26 +229,23 @@ def _bwd_kernel_one_col_block(
|
||||
|
||||
|
||||
@jit
|
||||
def _bwd_kernel(
|
||||
# fmt: off
|
||||
Q, K, V, sm_scale,
|
||||
Out, DO,
|
||||
DQ, DK, DV,
|
||||
L,
|
||||
D,
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
Z, H, N_CTX,
|
||||
Z_H_N_CTX,
|
||||
SQ_Z_H_N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
SEQUENCE_PARALLEL: tl.constexpr,
|
||||
CAUSAL: tl.constexpr,
|
||||
MMA_V3: tl.constexpr
|
||||
# fmt: on
|
||||
):
|
||||
def _bwd_kernel(Q, K, V, sm_scale, #
|
||||
Out, DO, #
|
||||
DQ, DK, DV, #
|
||||
L, #
|
||||
D, #
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vn, stride_vk, #
|
||||
Z, H, N_CTX, #
|
||||
Z_H_N_CTX, #
|
||||
SQ_Z_H_N_CTX, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
|
||||
BLOCK_N: tl.constexpr, #
|
||||
SEQUENCE_PARALLEL: tl.constexpr, #
|
||||
CAUSAL: tl.constexpr, #
|
||||
MMA_V3: tl.constexpr #
|
||||
):
|
||||
qk_scale = sm_scale * 1.44269504
|
||||
off_hz = tl.program_id(0)
|
||||
off_z = off_hz // H
|
||||
@@ -331,51 +322,46 @@ def _bwd_kernel(
|
||||
num_block_n = tl.cdiv(N_CTX, BLOCK_N)
|
||||
if not SEQUENCE_PARALLEL:
|
||||
for start_n in range(0, num_block_n):
|
||||
_bwd_kernel_one_col_block(
|
||||
# fmt: off
|
||||
Q, K, V, sm_scale, qk_scale, Out, DO,
|
||||
DQ, DK, DV,
|
||||
L,
|
||||
D,
|
||||
Q_block_ptr, K_block_ptr, V_block_ptr,
|
||||
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr,
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
Z, H, N_CTX,
|
||||
off_h, off_z, off_hz, start_n, num_block_n,
|
||||
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_N=BLOCK_N,
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
|
||||
CAUSAL=CAUSAL,
|
||||
MMA_V3=MMA_V3
|
||||
# fmt: on
|
||||
)
|
||||
_bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, #
|
||||
DQ, DK, DV, #
|
||||
L, #
|
||||
D, #
|
||||
Q_block_ptr, K_block_ptr, V_block_ptr, #
|
||||
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vn, stride_vk, #
|
||||
Z, H, N_CTX, #
|
||||
off_h, off_z, off_hz, start_n, num_block_n, #
|
||||
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, #
|
||||
BLOCK_N=BLOCK_N, #
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, #
|
||||
CAUSAL=CAUSAL, #
|
||||
MMA_V3=MMA_V3 #
|
||||
)
|
||||
else:
|
||||
start_n = tl.program_id(1)
|
||||
_bwd_kernel_one_col_block(
|
||||
# fmt: off
|
||||
Q, K, V, sm_scale, qk_scale, Out, DO,
|
||||
DQ, DK, DV,
|
||||
L,
|
||||
D,
|
||||
Q_block_ptr, K_block_ptr, V_block_ptr,
|
||||
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr,
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
Z, H, N_CTX,
|
||||
off_h, off_z, off_hz, start_n, num_block_n,
|
||||
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_N=BLOCK_N,
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
|
||||
CAUSAL=CAUSAL,
|
||||
MMA_V3=MMA_V3
|
||||
# fmt: on
|
||||
)
|
||||
_bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, #
|
||||
DQ, DK, DV, #
|
||||
L, #
|
||||
D, #
|
||||
Q_block_ptr, K_block_ptr, V_block_ptr, #
|
||||
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vn, stride_vk, #
|
||||
Z, H, N_CTX, #
|
||||
off_h, off_z, off_hz, start_n, num_block_n, #
|
||||
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, #
|
||||
BLOCK_N=BLOCK_N, #
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, #
|
||||
CAUSAL=CAUSAL, #
|
||||
MMA_V3=MMA_V3 #
|
||||
)
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False):
|
||||
# only support for Ampere now
|
||||
@@ -393,21 +379,19 @@ class _attention(torch.autograd.Function):
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
_fwd_kernel[grid](
|
||||
# fmt: off
|
||||
q, k, v, sm_scale,
|
||||
L,
|
||||
o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
q.shape[0] * q.shape[1] * q.shape[2],
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
|
||||
IS_CAUSAL=causal,
|
||||
num_warps=num_warps,
|
||||
num_stages=4,
|
||||
# fmt: on
|
||||
q, k, v, sm_scale, #
|
||||
L, #
|
||||
o, #
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
|
||||
q.shape[0], q.shape[1], q.shape[2], #
|
||||
q.shape[0] * q.shape[1] * q.shape[2], #
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, #
|
||||
IS_CAUSAL=causal, #
|
||||
num_warps=num_warps, #
|
||||
num_stages=4 #
|
||||
)
|
||||
|
||||
ctx.save_for_backward(q, k, v, o, L)
|
||||
@@ -429,14 +413,14 @@ class _attention(torch.autograd.Function):
|
||||
do = do.contiguous()
|
||||
if sequence_parallel:
|
||||
replicas = cdiv(seq_len_kv, BLOCK)
|
||||
new_dq_shape = (replicas,) + q.shape
|
||||
new_dq_shape = (replicas, ) + q.shape
|
||||
dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype)
|
||||
else:
|
||||
dq = torch.zeros_like(q, dtype=q.dtype)
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
delta = torch.empty_like(L)
|
||||
_bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1],)](
|
||||
_bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1], )](
|
||||
o,
|
||||
do,
|
||||
delta,
|
||||
@@ -444,26 +428,24 @@ class _attention(torch.autograd.Function):
|
||||
D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
_bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)](
|
||||
# fmt: off
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do,
|
||||
dq, dk, dv,
|
||||
L,
|
||||
delta,
|
||||
o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
q.shape[0] * q.shape[1] * q.shape[2],
|
||||
cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
|
||||
SEQUENCE_PARALLEL=sequence_parallel,
|
||||
CAUSAL=ctx.causal,
|
||||
MMA_V3=MMA_V3,
|
||||
num_warps=8,
|
||||
num_stages=1,
|
||||
# fmt: on
|
||||
q, k, v, ctx.sm_scale, #
|
||||
o, do, #
|
||||
dq, dk, dv, #
|
||||
L, #
|
||||
delta, #
|
||||
o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
|
||||
q.shape[0], q.shape[1], q.shape[2], #
|
||||
q.shape[0] * q.shape[1] * q.shape[2], #
|
||||
cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2], #
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK, #
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, #
|
||||
SEQUENCE_PARALLEL=sequence_parallel, #
|
||||
CAUSAL=ctx.causal, #
|
||||
MMA_V3=MMA_V3, #
|
||||
num_warps=8, #
|
||||
num_stages=1 #
|
||||
)
|
||||
|
||||
if len(dq.shape) == 5:
|
||||
|
||||
@@ -37,8 +37,9 @@ def get_configs_io_bound():
|
||||
num_stages=num_stages, num_warps=num_warps))
|
||||
# split_k
|
||||
for split_k in [2, 4, 8, 16]:
|
||||
configs.append(Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||
configs.append(
|
||||
Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||
return configs
|
||||
|
||||
|
||||
@@ -69,22 +70,22 @@ def get_configs_io_bound():
|
||||
prune_configs_by={
|
||||
'early_config_prune': early_config_prune,
|
||||
'perf_model': estimate_matmul_time,
|
||||
'top_k': 10
|
||||
'top_k': 10,
|
||||
},
|
||||
)
|
||||
@heuristics({
|
||||
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||
})
|
||||
@jit
|
||||
def _kernel(A, B, C, M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
dot_out_dtype: tl.constexpr,
|
||||
allow_tf32: tl.constexpr,
|
||||
fp8_fast_accum: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr
|
||||
def _kernel(A, B, C, M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
dot_out_dtype: tl.constexpr, #
|
||||
allow_tf32: tl.constexpr, #
|
||||
fp8_fast_accum: tl.constexpr, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr #
|
||||
):
|
||||
# matrix multiplication
|
||||
pid = tl.program_id(0)
|
||||
@@ -184,14 +185,15 @@ class _matmul(torch.autograd.Function):
|
||||
ab_dtype = False
|
||||
# launch kernel
|
||||
grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||
_kernel[grid](a, b, c, M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
dot_out_dtype=dot_out_dtype,
|
||||
allow_tf32=allow_tf32,
|
||||
fp8_fast_accum=fp8_fast_accum,
|
||||
GROUP_M=8, AB_DTYPE=ab_dtype)
|
||||
_kernel[grid](
|
||||
a, b, c, M, N, K, #
|
||||
a.stride(0), a.stride(1), #
|
||||
b.stride(0), b.stride(1), #
|
||||
c.stride(0), c.stride(1), #
|
||||
dot_out_dtype=dot_out_dtype, #
|
||||
allow_tf32=allow_tf32, #
|
||||
fp8_fast_accum=fp8_fast_accum, #
|
||||
GROUP_M=8, AB_DTYPE=ab_dtype)
|
||||
return c
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -5,8 +5,7 @@ import torch
|
||||
from .. import cdiv
|
||||
from .._C.libtriton.triton import runtime
|
||||
from ..runtime import driver
|
||||
from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops,
|
||||
nvsmi)
|
||||
from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops, nvsmi)
|
||||
|
||||
|
||||
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
@@ -14,7 +13,8 @@ def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
total_warps = num_ctas * min(num_warps, 4)
|
||||
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
|
||||
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
|
||||
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, cur_sm_clock, backend, device)
|
||||
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(
|
||||
dtype, cur_sm_clock, backend, device)
|
||||
return tflops
|
||||
|
||||
|
||||
@@ -35,12 +35,12 @@ def get_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
|
||||
|
||||
def estimate_matmul_time(
|
||||
# backend, device,
|
||||
num_warps, num_stages,
|
||||
A, B, C,
|
||||
M, N, K,
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
|
||||
debug=False, **kwargs
|
||||
# backend, device,
|
||||
num_warps, num_stages, #
|
||||
A, B, C, #
|
||||
M, N, K, #
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, #
|
||||
debug=False, **kwargs #
|
||||
):
|
||||
''' return estimated running time in ms
|
||||
= max(compute, loading) + store '''
|
||||
@@ -149,8 +149,9 @@ def early_config_prune(configs, named_args):
|
||||
optimal_num_stages = ldgsts_latency / mma_cycles
|
||||
|
||||
# nearest stages, prefer large #stages
|
||||
nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
|
||||
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
|
||||
nearest = heapq.nsmallest(
|
||||
2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
|
||||
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
|
||||
|
||||
for n in nearest:
|
||||
pruned_configs.append(n[0])
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune,
|
||||
heuristics)
|
||||
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune, heuristics)
|
||||
from .driver import driver
|
||||
from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret
|
||||
|
||||
|
||||
@@ -9,11 +9,10 @@ from .jit import KernelInterface
|
||||
|
||||
|
||||
class OutOfResources(Exception):
|
||||
|
||||
def __init__(self, required, limit, name):
|
||||
self.message = (
|
||||
f"out of resource: {name}, Required: {required}, Hardware limit: {limit}. "
|
||||
+ "Reducing block sizes or `num_stages` may help."
|
||||
)
|
||||
self.message = (f"out of resource: {name}, Required: {required}, Hardware limit: {limit}. " +
|
||||
"Reducing block sizes or `num_stages` may help.")
|
||||
self.required = required
|
||||
self.limit = limit
|
||||
self.name = name
|
||||
@@ -25,6 +24,7 @@ class OutOfResources(Exception):
|
||||
|
||||
|
||||
class Autotuner(KernelInterface):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fn,
|
||||
@@ -99,10 +99,8 @@ class Autotuner(KernelInterface):
|
||||
# as kwargs and by the autotuner
|
||||
conflicts = meta.keys() & config.kwargs.keys()
|
||||
if conflicts:
|
||||
raise ValueError(
|
||||
f"Conflicting meta-parameters: {', '.join(conflicts)}."
|
||||
" Make sure that you don't re-define auto-tuned symbols."
|
||||
)
|
||||
raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}."
|
||||
" Make sure that you don't re-define auto-tuned symbols.")
|
||||
# augment meta-parameters with tunable ones
|
||||
current = dict(meta, **config.kwargs)
|
||||
full_nargs = {**self.nargs, **current}
|
||||
@@ -179,7 +177,8 @@ class Autotuner(KernelInterface):
|
||||
top_k = int(len(self.configs) * top_k)
|
||||
if len(pruned_configs) > top_k:
|
||||
est_timing = {
|
||||
config: self.perf_model(
|
||||
config:
|
||||
self.perf_model(
|
||||
**self.nargs,
|
||||
**kwargs,
|
||||
**config.kwargs,
|
||||
@@ -296,6 +295,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_va
|
||||
|
||||
|
||||
class Heuristics(KernelInterface):
|
||||
|
||||
def __init__(self, fn, arg_names, values) -> None:
|
||||
self.fn = fn
|
||||
self.values = values
|
||||
|
||||
@@ -19,6 +19,7 @@ def default_dump_dir():
|
||||
|
||||
|
||||
class CacheManager(ABC):
|
||||
|
||||
def __init__(self, key):
|
||||
pass
|
||||
|
||||
@@ -44,6 +45,7 @@ class CacheManager(ABC):
|
||||
|
||||
|
||||
class FileCacheManager(CacheManager):
|
||||
|
||||
def __init__(self, key, override=False, dump=False):
|
||||
self.key = key
|
||||
self.lock_path = None
|
||||
|
||||
@@ -26,6 +26,7 @@ class DriverBase(metaclass=abc.ABCMeta):
|
||||
|
||||
|
||||
class CudaUtils(object):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(CudaUtils, cls).__new__(cls)
|
||||
@@ -65,6 +66,7 @@ class CudaUtils(object):
|
||||
|
||||
|
||||
class CudaDriver(DriverBase):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(CudaDriver, cls).__new__(cls)
|
||||
@@ -81,6 +83,7 @@ class CudaDriver(DriverBase):
|
||||
|
||||
|
||||
class HIPUtils(object):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(HIPUtils, cls).__new__(cls)
|
||||
@@ -111,6 +114,7 @@ class HIPUtils(object):
|
||||
|
||||
|
||||
class HIPDriver(DriverBase):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(HIPDriver, cls).__new__(cls)
|
||||
@@ -122,6 +126,7 @@ class HIPDriver(DriverBase):
|
||||
|
||||
|
||||
class UnsupportedDriver(DriverBase):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(UnsupportedDriver, cls).__new__(cls)
|
||||
@@ -138,6 +143,7 @@ class UnsupportedDriver(DriverBase):
|
||||
|
||||
|
||||
class LazyProxy:
|
||||
|
||||
def __init__(self, init_fn):
|
||||
self._init_fn = init_fn
|
||||
self._obj = None
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
class OutOfResources(Exception):
|
||||
|
||||
def __init__(self, required, limit, name):
|
||||
self.message = f"out of resource: {name}, " f"Required: {required}, " f"Hardware limit: {limit}"
|
||||
self.message += ". Reducing block sizes or `num_stages` may help."
|
||||
|
||||
@@ -37,6 +37,7 @@ def str_to_ty(name):
|
||||
|
||||
|
||||
class TensorHandle:
|
||||
|
||||
def __init__(self, data, dtype):
|
||||
self.data = data
|
||||
self.dtype = dtype
|
||||
@@ -46,6 +47,7 @@ class TensorHandle:
|
||||
|
||||
|
||||
class BlockPointerHandle:
|
||||
|
||||
def __init__(self, base, shape, strides, offsets, tensor_shape, order):
|
||||
self.base = base
|
||||
self.shape = shape
|
||||
@@ -72,7 +74,9 @@ class BlockPointerHandle:
|
||||
|
||||
|
||||
def wrap_ret(compute_ret_ty):
|
||||
|
||||
def wrapper(fn):
|
||||
|
||||
def wrapped(*args, **kwargs):
|
||||
ret = fn(*args, **kwargs)
|
||||
return TensorHandle(ret.data, compute_ret_ty(*args, **kwargs))
|
||||
@@ -83,6 +87,7 @@ def wrap_ret(compute_ret_ty):
|
||||
|
||||
|
||||
class Builder:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.arch = None
|
||||
# pass
|
||||
@@ -280,9 +285,8 @@ class Builder:
|
||||
dtype_tt = ptr.dtype.element_ty
|
||||
return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype)
|
||||
|
||||
def create_tensor_pointer_load(
|
||||
self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile
|
||||
):
|
||||
def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy,
|
||||
is_volatile):
|
||||
ptrs, masks = ptr.materialize_pointers(boundary_check)
|
||||
assert padding_option is None
|
||||
other = None
|
||||
@@ -364,9 +368,10 @@ class Builder:
|
||||
|
||||
|
||||
def patch_attr(obj, name, member, builder):
|
||||
new_member = lambda *args, member=member, **kwargs: (
|
||||
member(*args, **{k: v for k, v in kwargs.items() if k != "_builder"}, _builder=builder)
|
||||
)
|
||||
new_member = lambda *args, member=member, **kwargs: (member(*args, **
|
||||
{k: v
|
||||
for k, v in kwargs.items()
|
||||
if k != "_builder"}, _builder=builder))
|
||||
setattr(obj, name, new_member)
|
||||
|
||||
|
||||
@@ -412,6 +417,7 @@ def _patch_lang_math(lang, builder):
|
||||
}
|
||||
|
||||
def make_numpy(name):
|
||||
|
||||
def impl(*args, **kwargs):
|
||||
ret_type = args[0].type # TODO: incorrect
|
||||
ret_dtype = args[0].dtype # TODO: incorrect
|
||||
@@ -424,14 +430,13 @@ def _patch_lang_math(lang, builder):
|
||||
return impl
|
||||
|
||||
def make_fallback(name):
|
||||
|
||||
def fallback(*args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
f"""
|
||||
raise NotImplementedError(f"""
|
||||
{name} not supported in interpreter mode: no known numpy implementation.
|
||||
If you think that {name} in fact does have a numpy implementation, please add it
|
||||
to the mapping in python/triton/interpreter/new_interpreter.py:_patch_lang_math.
|
||||
"""
|
||||
)
|
||||
""")
|
||||
|
||||
return fallback
|
||||
|
||||
@@ -467,6 +472,7 @@ RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_warp_specializati
|
||||
|
||||
|
||||
class GridExecutor:
|
||||
|
||||
def __init__(self, fn, arg_names, grid):
|
||||
from .jit import _normalize_ty # TODO: modularize
|
||||
|
||||
@@ -496,7 +502,7 @@ class GridExecutor:
|
||||
# iterate through grid
|
||||
grid = self.grid(args) if callable(self.grid) else self.grid
|
||||
assert len(grid) <= 3
|
||||
grid = grid + (1,) * (3 - len(grid))
|
||||
grid = grid + (1, ) * (3 - len(grid))
|
||||
builder.set_grid_dim(*grid)
|
||||
for x in range(grid[0]):
|
||||
for y in range(grid[1]):
|
||||
@@ -510,6 +516,7 @@ class GridExecutor:
|
||||
|
||||
|
||||
class InterpretedFunction:
|
||||
|
||||
def _patch_lang(self, builder):
|
||||
lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
|
||||
assert len(lang) == 1, "triton.language must be visible from within jit'd function"
|
||||
|
||||
@@ -72,9 +72,8 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
lhs = self.visit(node.value)
|
||||
while isinstance(lhs, ast.Attribute):
|
||||
lhs = self.visit(lhs.value)
|
||||
if lhs is None or (
|
||||
getattr(lhs, "__name__", "") == "triton" or getattr(lhs, "__name__", "").endswith(".triton")
|
||||
):
|
||||
if lhs is None or (getattr(lhs, "__name__", "") == "triton"
|
||||
or getattr(lhs, "__name__", "").endswith(".triton")):
|
||||
return None
|
||||
return getattr(lhs, node.attr)
|
||||
|
||||
@@ -176,7 +175,7 @@ class KernelArg:
|
||||
assert not self.param.do_not_specialize
|
||||
|
||||
try:
|
||||
return (self.value.data_ptr() % JITFunction.divisibility == 0,)
|
||||
return (self.value.data_ptr() % JITFunction.divisibility == 0, )
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@@ -188,7 +187,7 @@ class KernelArg:
|
||||
self.value == 1,
|
||||
)
|
||||
|
||||
return (False,)
|
||||
return (False, )
|
||||
|
||||
|
||||
class KernelInterface(Generic[T]):
|
||||
@@ -253,10 +252,11 @@ class JITFunction(KernelInterface[T]):
|
||||
return arg.data_ptr() % JITFunction.divisibility == 0
|
||||
elif isinstance(arg, int):
|
||||
return (arg % 16 == 0, arg == 1)
|
||||
return (arg is None,)
|
||||
return (arg is None, )
|
||||
|
||||
# TODO(jlebar): Fold this into the KernelArg class.
|
||||
def _get_config(self, *args):
|
||||
|
||||
def is_divisible_by_16(x):
|
||||
if hasattr(x, "data_ptr"):
|
||||
return x.data_ptr() % JITFunction.divisibility == 0
|
||||
@@ -279,7 +279,9 @@ class JITFunction(KernelInterface[T]):
|
||||
if is_divisible_by_16(arg) and not param.do_not_specialize
|
||||
}
|
||||
divisible_by_8 = {
|
||||
param.num for param, arg in zip(self.params, args) if is_divisible_by_8(arg) and not param.do_not_specialize
|
||||
param.num
|
||||
for param, arg in zip(self.params, args)
|
||||
if is_divisible_by_8(arg) and not param.do_not_specialize
|
||||
}
|
||||
equal_to_1 = {
|
||||
param.num
|
||||
@@ -290,9 +292,10 @@ class JITFunction(KernelInterface[T]):
|
||||
# TODO: method to collect all folded args
|
||||
none_args = {param.num for param, arg in zip(self.params, args) if arg is None and not param.do_not_specialize}
|
||||
ids_of_folded_args = equal_to_1 | none_args
|
||||
return namedtuple(
|
||||
"instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"]
|
||||
)(tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), tuple(divisible_by_8))
|
||||
return namedtuple("instance_descriptor",
|
||||
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])( #
|
||||
tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args),
|
||||
tuple(divisible_by_8))
|
||||
# return _triton.code_gen.instance_descriptor(divisible_by_16,
|
||||
# equal_to_1)
|
||||
|
||||
@@ -356,6 +359,7 @@ class JITFunction(KernelInterface[T]):
|
||||
key = str(key)
|
||||
|
||||
class LegacyCompiler:
|
||||
|
||||
def __init__(self, module, name):
|
||||
self.module = module
|
||||
self.name = name
|
||||
@@ -449,9 +453,8 @@ class JITFunction(KernelInterface[T]):
|
||||
if device_type is None:
|
||||
device_types = [self._device_of(arg) for arg in non_constexpr_arg_values]
|
||||
device_types = [_device_type for _device_type in device_types if _device_type != ""]
|
||||
device_type = self._conclude_device_type(
|
||||
device_types, [self._pinned_memory_of(arg) for arg in non_constexpr_arg_values]
|
||||
)
|
||||
device_type = self._conclude_device_type(device_types,
|
||||
[self._pinned_memory_of(arg) for arg in non_constexpr_arg_values])
|
||||
|
||||
device_backend = None
|
||||
if device_type not in ["cuda"]:
|
||||
@@ -498,7 +501,7 @@ class JITFunction(KernelInterface[T]):
|
||||
|
||||
# Kernel is not cached; we have to compile.
|
||||
if key not in self.cache[device]:
|
||||
configs = (self._get_config(*[arg.value for arg in args]),)
|
||||
configs = (self._get_config(*[arg.value for arg in args]), )
|
||||
constants = {
|
||||
arg.param.num: arg.value
|
||||
for arg in args
|
||||
@@ -510,21 +513,23 @@ class JITFunction(KernelInterface[T]):
|
||||
|
||||
# Build kernel signature -- doesn't include constexpr arguments.
|
||||
signature = {
|
||||
arg.param.num: self._type_of(self._key_of(arg.value)) for arg in args if not arg.param.is_constexpr
|
||||
arg.param.num: self._type_of(self._key_of(arg.value))
|
||||
for arg in args
|
||||
if not arg.param.is_constexpr
|
||||
}
|
||||
|
||||
if self._call_hook(
|
||||
key,
|
||||
signature,
|
||||
device,
|
||||
constants,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
num_stages,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
extern_libs,
|
||||
configs,
|
||||
key,
|
||||
signature,
|
||||
device,
|
||||
constants,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
num_stages,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
extern_libs,
|
||||
configs,
|
||||
):
|
||||
return None
|
||||
|
||||
@@ -581,7 +586,7 @@ class JITFunction(KernelInterface[T]):
|
||||
|
||||
# function source code (without decorators)
|
||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||
self.src = self.src[self.src.find("def") :]
|
||||
self.src = self.src[self.src.find("def"):]
|
||||
# cache of just-in-time compiled kernels
|
||||
self.cache = defaultdict(dict)
|
||||
self.hash = None
|
||||
@@ -734,6 +739,7 @@ class MockTensor:
|
||||
|
||||
|
||||
class TensorWrapper:
|
||||
|
||||
def __init__(self, base, dtype):
|
||||
self.dtype = dtype
|
||||
self.base = base
|
||||
|
||||
@@ -78,10 +78,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None):
|
||||
return torch.mean(torch.tensor(ret)).item()
|
||||
|
||||
|
||||
def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
|
||||
quantiles=None,
|
||||
fast_flush=True,
|
||||
return_mode="mean"):
|
||||
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"):
|
||||
assert return_mode in ["min", "max", "mean", "median"]
|
||||
import torch
|
||||
"""
|
||||
@@ -261,6 +258,7 @@ class Benchmark:
|
||||
|
||||
|
||||
class Mark:
|
||||
|
||||
def __init__(self, fn, benchmarks):
|
||||
self.fn = fn
|
||||
self.benchmarks = benchmarks
|
||||
@@ -405,12 +403,15 @@ def get_max_tensorcore_tflops(dtype, clock_rate, backend=None, device=None):
|
||||
tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
|
||||
return tflops
|
||||
|
||||
|
||||
# create decorator that wraps test function into
|
||||
# a cuda-memcheck system call
|
||||
|
||||
|
||||
def cuda_memcheck(**target_kwargs):
|
||||
|
||||
def decorator(test_fn):
|
||||
|
||||
@functools.wraps(test_fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
import psutil
|
||||
@@ -428,7 +429,9 @@ def cuda_memcheck(**target_kwargs):
|
||||
assert "ERROR SUMMARY: 0 errors" in str(out.stdout)
|
||||
else:
|
||||
test_fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@@ -436,22 +439,18 @@ def cuda_memcheck(**target_kwargs):
|
||||
def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
|
||||
try:
|
||||
subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"])
|
||||
subprocess.check_output(
|
||||
[
|
||||
"nvidia-smi",
|
||||
"-i",
|
||||
"0",
|
||||
f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
|
||||
]
|
||||
)
|
||||
subprocess.check_output(
|
||||
[
|
||||
"nvidia-smi",
|
||||
"-i",
|
||||
"0",
|
||||
f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
|
||||
]
|
||||
)
|
||||
subprocess.check_output([
|
||||
"nvidia-smi",
|
||||
"-i",
|
||||
"0",
|
||||
f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
|
||||
])
|
||||
subprocess.check_output([
|
||||
"nvidia-smi",
|
||||
"-i",
|
||||
"0",
|
||||
f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
|
||||
])
|
||||
cur_sm_clock = nvsmi(["clocks.current.sm"])[0]
|
||||
cur_mem_clock = nvsmi(["clocks.current.memory"])[0]
|
||||
assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz"
|
||||
|
||||
@@ -141,8 +141,7 @@ class ExternLibrary(ABC):
|
||||
f.write(file_str)
|
||||
f.close()
|
||||
if self._format:
|
||||
subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file],
|
||||
stdout=subprocess.PIPE).communicate()
|
||||
subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], stdout=subprocess.PIPE).communicate()
|
||||
subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate()
|
||||
|
||||
|
||||
@@ -208,56 +207,36 @@ class Libdevice(ExternLibrary):
|
||||
|
||||
# Group functions together by renaming.
|
||||
renaming = {
|
||||
'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh',
|
||||
'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': 'add_rn',
|
||||
'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru',
|
||||
'dadd_rz': 'add_rz', 'fadd_rz': 'add_rz', 'asinf': 'asin',
|
||||
'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2',
|
||||
'atanhf': 'atanh', 'brevll': 'brev', 'cbrtf': 'cbrt',
|
||||
'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign',
|
||||
'cosf': 'cos', 'coshf': 'cosh', 'cospif': 'cospi',
|
||||
'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
|
||||
'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn',
|
||||
'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', 'ddiv_ru': 'div_ru',
|
||||
'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf',
|
||||
'erfcf': 'erfc', 'erfcinvf': 'erfcinv', 'erfcxf': 'erfcx',
|
||||
'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10',
|
||||
'exp2f': 'exp2', 'expm1f': 'expm1', 'fabsf': 'abs',
|
||||
'fabs': 'abs', 'fast_fdividef': 'fast_dividef',
|
||||
'fdimf': 'fdim', 'ffsll': 'ffs', 'floorf': 'floor',
|
||||
'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn',
|
||||
'fmaf_ru': 'fma_ru', 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod',
|
||||
'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb',
|
||||
'isinff': 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan',
|
||||
'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
|
||||
'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint',
|
||||
'llroundf': 'llround', 'logf': 'log', 'log10f': 'log10',
|
||||
'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb',
|
||||
'umax': 'max', 'llmax': 'max', 'ullmax': 'max', 'fmaxf': 'max',
|
||||
'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min',
|
||||
'fminf': 'min', 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd',
|
||||
'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', 'dmul_ru': 'mul_ru',
|
||||
'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz',
|
||||
'umul24': 'mul24', 'umulhi': 'mulhi', 'mul64hi': 'mulhi',
|
||||
'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': 'nextafter',
|
||||
'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf',
|
||||
'normcdfinvf': 'normcdfinv', 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow',
|
||||
'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', 'drcp_rd': 'rcp_rd',
|
||||
'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru',
|
||||
'drcp_ru': 'rcp_ru', 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz',
|
||||
'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
|
||||
'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d',
|
||||
'roundf': 'round', 'rsqrtf': 'rsqrt', 'frsqrt_rn': 'rsqrt_rn',
|
||||
'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit',
|
||||
'signbitd': 'signbit', 'sinf': 'sin', 'sinhf': 'sinh',
|
||||
'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd',
|
||||
'dsqrt_rd': 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn',
|
||||
'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', 'fsqrt_rz': 'sqrt_rz',
|
||||
'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd',
|
||||
'fsub_rn': 'sub_rn', 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru',
|
||||
'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
|
||||
'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc',
|
||||
'y0f': 'y0', 'y1f': 'y1', 'ynf': 'yn'
|
||||
'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn':
|
||||
'add_rn', 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', 'dadd_rz': 'add_rz', 'fadd_rz':
|
||||
'add_rz', 'asinf': 'asin', 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', 'atanhf': 'atanh',
|
||||
'brevll': 'brev', 'cbrtf': 'cbrt', 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', 'cosf': 'cos',
|
||||
'coshf': 'cosh', 'cospif': 'cospi', 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
|
||||
'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru',
|
||||
'ddiv_ru': 'div_ru', 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', 'erfcf': 'erfc', 'erfcinvf':
|
||||
'erfcinv', 'erfcxf': 'erfcx', 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', 'exp2f': 'exp2',
|
||||
'expm1f': 'expm1', 'fabsf': 'abs', 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', 'fdimf': 'fdim', 'ffsll':
|
||||
'ffs', 'floorf': 'floor', 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', 'fmaf_ru': 'fma_ru',
|
||||
'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', 'isinff':
|
||||
'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
|
||||
'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', 'llroundf': 'llround', 'logf': 'log', 'log10f':
|
||||
'log10', 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', 'umax': 'max', 'llmax': 'max', 'ullmax':
|
||||
'max', 'fmaxf': 'max', 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', 'fminf': 'min',
|
||||
'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn',
|
||||
'dmul_ru': 'mul_ru', 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', 'umul24': 'mul24',
|
||||
'umulhi': 'mulhi', 'mul64hi': 'mulhi', 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf':
|
||||
'nextafter', 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', 'normcdfinvf': 'normcdfinv',
|
||||
'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd',
|
||||
'drcp_rd': 'rcp_rd', 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', 'drcp_ru': 'rcp_ru',
|
||||
'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
|
||||
'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', 'roundf': 'round', 'rsqrtf': 'rsqrt',
|
||||
'frsqrt_rn': 'rsqrt_rn', 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', 'signbitd': 'signbit',
|
||||
'sinf': 'sin', 'sinhf': 'sinh', 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', 'dsqrt_rd':
|
||||
'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru',
|
||||
'fsqrt_rz': 'sqrt_rz', 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', 'fsub_rn': 'sub_rn',
|
||||
'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
|
||||
'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', 'y0f': 'y0', 'y1f': 'y1', 'ynf':
|
||||
'yn'
|
||||
}
|
||||
|
||||
for symbol in self._symbols.values():
|
||||
@@ -347,8 +326,7 @@ class LLVMDisassembler:
|
||||
self._ll_file = "/tmp/extern_lib.ll"
|
||||
|
||||
def disasm(self, lib_path: str) -> None:
|
||||
subprocess.Popen([self._path, lib_path, "-o", self.ll_file],
|
||||
stdout=subprocess.PIPE).communicate()
|
||||
subprocess.Popen([self._path, lib_path, "-o", self.ll_file], stdout=subprocess.PIPE).communicate()
|
||||
|
||||
@property
|
||||
def ll_file(self) -> str:
|
||||
|
||||
@@ -40,10 +40,13 @@ if __name__ == "__main__":
|
||||
|
||||
# command-line arguments
|
||||
parser = ArgumentParser(description=desc)
|
||||
parser.add_argument("path", help="Path to Python source containing desired kernel in its scope. File will be executed.")
|
||||
parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", required=True)
|
||||
parser.add_argument("path",
|
||||
help="Path to Python source containing desired kernel in its scope. File will be executed.")
|
||||
parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile",
|
||||
required=True)
|
||||
parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel")
|
||||
parser.add_argument("--num-stages", "-ns", type=int, default=3, help="Number of stages (meta-parameter of the kernel)")
|
||||
parser.add_argument("--num-stages", "-ns", type=int, default=3,
|
||||
help="Number of stages (meta-parameter of the kernel)")
|
||||
parser.add_argument("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel")
|
||||
parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename")
|
||||
parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True)
|
||||
@@ -104,7 +107,8 @@ if __name__ == "__main__":
|
||||
config = triton.compiler.instance_descriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1)
|
||||
for i in equal_to_1:
|
||||
constexprs.update({i: 1})
|
||||
ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config], num_warps=args.num_warps, num_stages=args.num_stages)
|
||||
ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config],
|
||||
num_warps=args.num_warps, num_stages=args.num_stages)
|
||||
arg_names = []
|
||||
arg_types = []
|
||||
for i in signature.keys():
|
||||
|
||||
@@ -27,13 +27,12 @@ class KernelLinkerMeta:
|
||||
|
||||
|
||||
class HeaderParser:
|
||||
|
||||
def __init__(self) -> None:
|
||||
import re
|
||||
|
||||
# [kernel_name, c signature]
|
||||
self.linker_directives = re.compile(
|
||||
"//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)"
|
||||
)
|
||||
self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)")
|
||||
# [name, hash, suffix]
|
||||
self.kernel_name = re.compile("^([\\w]+)_([\\w]+)_([\\w]+)$")
|
||||
# [(type, name)]
|
||||
@@ -153,9 +152,7 @@ void unload_{meta.orig_kernel_name}();
|
||||
# generate dispatcher function for kernels with different meta-parameter and constant values
|
||||
def make_default_algo_kernel(meta: KernelLinkerMeta) -> str:
|
||||
src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n"
|
||||
src += (
|
||||
f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n"
|
||||
)
|
||||
src += (f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n")
|
||||
src += "}\n"
|
||||
return src
|
||||
|
||||
@@ -167,28 +164,22 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -
|
||||
src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n"
|
||||
src += "\n"
|
||||
|
||||
src += (
|
||||
f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{"
|
||||
)
|
||||
src += (f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{")
|
||||
src += "\n"
|
||||
for meta in sorted(metas, key=lambda m: -m.num_specs):
|
||||
cond_fn = (
|
||||
lambda val, hint: f"({val} % {hint} == 0)"
|
||||
if hint == 16
|
||||
else f"({val} == {hint})"
|
||||
if hint == 1
|
||||
else None
|
||||
)
|
||||
conds = " && ".join(
|
||||
[
|
||||
cond_fn(val, hint)
|
||||
for val, hint in zip(meta.arg_names, meta.sizes)
|
||||
if hint is not None
|
||||
]
|
||||
)
|
||||
src += (
|
||||
f" if ({conds})\n" if any(meta.sizes) else "if (1)\n"
|
||||
) # Edge case where no specializations hence no dispatching required
|
||||
cond_fn = ( #
|
||||
lambda val, hint: f"({val} % {hint} == 0)" #
|
||||
if hint == 16 #
|
||||
else f"({val} == {hint})" #
|
||||
if hint == 1 #
|
||||
else None)
|
||||
conds = " && ".join([ #
|
||||
cond_fn(val, hint) #
|
||||
for val, hint in zip(meta.arg_names, meta.sizes) #
|
||||
if hint is not None
|
||||
])
|
||||
src += (f" if ({conds})\n" if any(meta.sizes) else "if (1)\n"
|
||||
) # Edge case where no specializations hence no dispatching required
|
||||
arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1]
|
||||
src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n"
|
||||
src += "\n"
|
||||
@@ -202,9 +193,7 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -
|
||||
src += f"void {mode}_{name}() {{"
|
||||
src += "\n"
|
||||
for meta in sorted(metas, key=lambda m: -m.num_specs):
|
||||
src += (
|
||||
f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n"
|
||||
)
|
||||
src += (f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n")
|
||||
src += "}\n"
|
||||
return src
|
||||
|
||||
@@ -306,10 +295,7 @@ if __name__ == "__main__":
|
||||
fp.write(out)
|
||||
|
||||
# generate source
|
||||
defs = [
|
||||
make_kernel_hints_dispatcher(name, meta)
|
||||
for name, meta in parser.kernels.items()
|
||||
]
|
||||
defs = [make_kernel_hints_dispatcher(name, meta) for name, meta in parser.kernels.items()]
|
||||
names = [name for name in parser.kernels.keys()]
|
||||
func_pointers_def = make_func_pointers(names, meta)
|
||||
meta_const_def = make_kernel_meta_const_dispatcher(meta)
|
||||
|
||||
Reference in New Issue
Block a user