mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge commit 'cb3d79a185e40c9d8a579bea07747a8a8d157d52' into ifu-231117
Conflicts: lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp lib/Dialect/TritonGPU/IR/Dialect.cpp python/setup.py python/test/unit/language/assert_helper.py python/test/unit/operators/test_flash_attention.py python/test/unit/runtime/test_subproc.py python/triton/compiler/compiler.py python/triton/language/semantic.py python/triton/runtime/autotuner.py python/triton/runtime/jit.py python/tutorials/03-matrix-multiplication.py python/tutorials/05-layer-norm.py python/tutorials/06-fused-attention.py python/tutorials/11-grouped-gemm.py test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
@@ -45,12 +45,12 @@ __all__ = [
|
||||
"tools",
|
||||
]
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
# misc. utilities that don't fit well
|
||||
# into any specific module
|
||||
# -------------------------------------
|
||||
|
||||
|
||||
def cdiv(x: int, y: int):
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
@@ -10,8 +10,12 @@ from typing import Dict
|
||||
|
||||
from ..runtime.driver import DriverBase
|
||||
|
||||
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
TRITON_VERSION = "2.1.0"
|
||||
|
||||
|
||||
class BaseBackend:
|
||||
|
||||
def __init__(self, device_type: str) -> None:
|
||||
self.device_type = device_type
|
||||
|
||||
@@ -104,7 +108,7 @@ def get_backend(device_type: str):
|
||||
def _path_to_binary(binary: str):
|
||||
base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
|
||||
paths = [
|
||||
os.environ.get("TRITON_PTXAS_PATH", ""),
|
||||
os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
|
||||
os.path.join(base_dir, "third_party", "cuda", "bin", binary)
|
||||
]
|
||||
|
||||
@@ -132,3 +136,48 @@ def path_to_cuobjdump():
|
||||
@functools.lru_cache()
|
||||
def path_to_nvdisasm():
|
||||
return _path_to_binary("nvdisasm")
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def compute_core_version_key():
|
||||
import pkgutil
|
||||
contents = []
|
||||
# frontend
|
||||
with open(__file__, "rb") as f:
|
||||
contents += [hashlib.sha1(f.read()).hexdigest()]
|
||||
# compiler
|
||||
compiler_path = os.path.join(TRITON_PATH, 'compiler')
|
||||
for lib in pkgutil.iter_modules([compiler_path]):
|
||||
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
||||
contents += [hashlib.sha1(f.read()).hexdigest()]
|
||||
# backend
|
||||
libtriton_hash = hashlib.sha1()
|
||||
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
|
||||
while True:
|
||||
chunk = f.read(1024**2)
|
||||
if not chunk:
|
||||
break
|
||||
libtriton_hash.update(chunk)
|
||||
contents.append(libtriton_hash.hexdigest())
|
||||
# language
|
||||
language_path = os.path.join(TRITON_PATH, 'language')
|
||||
for lib in pkgutil.iter_modules([language_path]):
|
||||
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
||||
contents += [hashlib.sha1(f.read()).hexdigest()]
|
||||
return '-'.join(TRITON_VERSION) + '-'.join(contents)
|
||||
|
||||
|
||||
_cached_cuda_version_key = None
|
||||
|
||||
|
||||
def get_cuda_version_key():
|
||||
global _cached_cuda_version_key
|
||||
if _cached_cuda_version_key is None:
|
||||
key = compute_core_version_key()
|
||||
try:
|
||||
ptxas = path_to_ptxas()[0]
|
||||
ptxas_version = subprocess.check_output([ptxas, "--version"])
|
||||
except RuntimeError:
|
||||
ptxas_version = b"NO_PTXAS"
|
||||
_cached_cuda_version_key = key + '-' + hashlib.sha1(ptxas_version).hexdigest()
|
||||
return _cached_cuda_version_key
|
||||
|
||||
@@ -92,9 +92,15 @@ def _build(name, src, srcdir):
|
||||
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
|
||||
|
||||
if is_hip():
|
||||
ret = subprocess.check_call([cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{hip_lib_dir}", "-lamdhip64", "-o", so])
|
||||
ret = subprocess.check_call([
|
||||
cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC",
|
||||
f"-L{hip_lib_dir}", "-lamdhip64", "-o", so
|
||||
])
|
||||
else:
|
||||
cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
|
||||
cc_cmd = [
|
||||
cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda",
|
||||
"-o", so
|
||||
]
|
||||
cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
|
||||
ret = subprocess.check_call(cc_cmd)
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from .compiler import (CompiledKernel, compile, get_arch_default_num_stages,
|
||||
get_arch_default_num_warps, instance_descriptor)
|
||||
from .compiler import (CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps,
|
||||
instance_descriptor)
|
||||
from .errors import CompilationError
|
||||
|
||||
__all__ = ["compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", "get_arch_default_num_stages"]
|
||||
__all__ = [
|
||||
"compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps",
|
||||
"get_arch_default_num_stages"
|
||||
]
|
||||
|
||||
@@ -10,8 +10,7 @@ from .._C.libtriton.triton import ir
|
||||
from ..language import constexpr, tensor
|
||||
# ideally we wouldn't need any runtime component
|
||||
from ..runtime import JITFunction
|
||||
from .errors import (CompilationError, CompileTimeAssertionFailure,
|
||||
UnsupportedLanguageConstruct)
|
||||
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
|
||||
|
||||
|
||||
def mangle_ty(ty):
|
||||
@@ -68,7 +67,10 @@ def _check_fn_args(node, fn, args):
|
||||
if fn.noinline:
|
||||
for idx, arg in enumerate(args):
|
||||
if not _is_constexpr(arg) and not _is_triton_scalar(arg):
|
||||
raise UnsupportedLanguageConstruct(fn.src, node, f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}')
|
||||
raise UnsupportedLanguageConstruct(
|
||||
fn.src, node,
|
||||
f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}'
|
||||
)
|
||||
|
||||
|
||||
def _get_fn_file_line(fn):
|
||||
@@ -89,6 +91,7 @@ _condition_types = {bool, int, type(None)} # Python types accepted for conditio
|
||||
|
||||
|
||||
class enter_sub_region:
|
||||
|
||||
def __init__(self, generator):
|
||||
self.generator = generator
|
||||
|
||||
@@ -109,6 +112,7 @@ class enter_sub_region:
|
||||
|
||||
# Check if the given syntax node has an "early" return
|
||||
class ContainsReturnChecker(ast.NodeVisitor):
|
||||
|
||||
def __init__(self, gscope):
|
||||
self.gscope = gscope
|
||||
|
||||
@@ -199,9 +203,10 @@ class ContainsReturnChecker(ast.NodeVisitor):
|
||||
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, function_name, target,
|
||||
module=None, is_kernel=False, function_types: Optional[Dict] = None,
|
||||
debug=False, noinline=False, file_name: Optional[str] = None, begin_line=0):
|
||||
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, function_name, target, module=None,
|
||||
is_kernel=False, function_types: Optional[Dict] = None, debug=False, noinline=False,
|
||||
file_name: Optional[str] = None, begin_line=0):
|
||||
self.context = context
|
||||
self.builder = ir.builder(context)
|
||||
self.file_name = file_name
|
||||
@@ -237,8 +242,10 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
))
|
||||
|
||||
def _define_name_lookup(self):
|
||||
|
||||
def local_lookup(name: str, absent):
|
||||
value = self.lscope.get(name, absent) # this needs to be re-fetched from `self` every time, because it gets switched occasionally
|
||||
# this needs to be re-fetched from `self` every time, because it gets switched occasionally
|
||||
value = self.lscope.get(name, absent)
|
||||
if value is not absent and name not in self.local_defs:
|
||||
self.global_uses[name] = value
|
||||
return value
|
||||
@@ -255,8 +262,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
return name_lookup
|
||||
|
||||
def set_value(self, name: str,
|
||||
value: Union[tensor, constexpr]) -> None:
|
||||
def set_value(self, name: str, value: Union[tensor, constexpr]) -> None:
|
||||
''' This function:
|
||||
called by visit_Assign() & visit_FunctionDef() to store left value (lvalue)
|
||||
1. record local defined name (FIXME: should consider control flow)
|
||||
@@ -338,7 +344,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.visit(init_node)
|
||||
# initialize function
|
||||
visibility = "public" if self.is_kernel else "private"
|
||||
self.fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline)
|
||||
self.fn = self.builder.get_or_insert_function(self.module, self.function_name,
|
||||
self.prototype.to_ir(self.builder), visibility, self.noinline)
|
||||
self.module.push_back(self.fn)
|
||||
entry = self.fn.add_entry_block()
|
||||
arg_values = []
|
||||
@@ -469,12 +476,23 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
rhs = self.visit(node.right)
|
||||
method_name = self._method_name_for_bin_op.get(type(node.op))
|
||||
if method_name is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
return self._apply_binary_method(method_name, lhs, rhs)
|
||||
|
||||
_method_name_for_bin_op: Dict[Type[ast.operator], str] = {
|
||||
ast.Add: '__add__', ast.Sub: '__sub__', ast.Mult: '__mul__', ast.Div: '__truediv__',
|
||||
ast.FloorDiv: '__floordiv__', ast.Mod: '__mod__', ast.Pow: '__pow__',
|
||||
ast.LShift: '__lshift__', ast.RShift: '__rshift__', ast.BitAnd: '__and__', ast.BitOr: '__or__', ast.BitXor: '__xor__',
|
||||
ast.Add: '__add__',
|
||||
ast.Sub: '__sub__',
|
||||
ast.Mult: '__mul__',
|
||||
ast.Div: '__truediv__',
|
||||
ast.FloorDiv: '__floordiv__',
|
||||
ast.Mod: '__mod__',
|
||||
ast.Pow: '__pow__',
|
||||
ast.LShift: '__lshift__',
|
||||
ast.RShift: '__rshift__',
|
||||
ast.BitAnd: '__and__',
|
||||
ast.BitOr: '__or__',
|
||||
ast.BitXor: '__xor__',
|
||||
}
|
||||
|
||||
def visit_then_else_blocks(self, node, liveins, then_block, else_block):
|
||||
@@ -508,7 +526,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if name in then_defs or name in else_defs:
|
||||
names.append(name)
|
||||
ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type)
|
||||
ir_ret_types.append(then_defs[name].handle.get_type() if name in then_defs else else_defs[name].handle.get_type())
|
||||
ir_ret_types.append(then_defs[name].handle.get_type() if name in
|
||||
then_defs else else_defs[name].handle.get_type())
|
||||
# variable defined in then but not in else
|
||||
if name in then_defs and name not in else_defs:
|
||||
else_defs[name] = liveins[name]
|
||||
@@ -602,8 +621,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
contains_return = ContainsReturnChecker(self.gscope).visit(node)
|
||||
if self.scf_stack and contains_return:
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node,
|
||||
"Cannot have `return` statements inside `while` or `for` statements in triton "
|
||||
None, node, "Cannot have `return` statements inside `while` or `for` statements in triton "
|
||||
"(note that this also applies to `return` statements that are inside functions "
|
||||
"transitively called from within `while`/`for` statements)")
|
||||
elif self.scf_stack or not contains_return:
|
||||
@@ -612,10 +630,13 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.visit_if_top_level(cond, node)
|
||||
else:
|
||||
cond = _unwrap_if_constexpr(cond)
|
||||
if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks
|
||||
# not isinstance - we insist the real thing, no subclasses and no ducks
|
||||
if type(cond) not in _condition_types:
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
|
||||
', '.join(_.__name__ for _ in _condition_types), type(cond).__name__))
|
||||
None, node,
|
||||
"`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
|
||||
', '.join(_.__name__ for _ in _condition_types),
|
||||
type(cond).__name__))
|
||||
if cond:
|
||||
self.visit_compound_statement(node.body)
|
||||
else:
|
||||
@@ -624,15 +645,52 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def visit_IfExp(self, node):
|
||||
cond = self.visit(node.test)
|
||||
if _is_triton_tensor(cond):
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node,
|
||||
"Triton does not support `if` expressions (ternary operators) with dynamic conditions, use `if` statements instead")
|
||||
cond = cond.to(language.int1, _builder=self.builder)
|
||||
# TODO: Deal w/ more complicated return types (e.g tuple)
|
||||
with enter_sub_region(self):
|
||||
ip, last_loc = self._get_insertion_point_and_loc()
|
||||
|
||||
then_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(then_block)
|
||||
then_val = language.core._to_tensor(self.visit(node.body), self.builder)
|
||||
then_block = self.builder.get_insertion_block()
|
||||
|
||||
else_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(else_block)
|
||||
# do not need to reset lscope since
|
||||
# ternary expressions cannot define new variables
|
||||
else_val = language.core._to_tensor(self.visit(node.orelse), self.builder)
|
||||
else_block = self.builder.get_insertion_block()
|
||||
|
||||
self._set_insertion_point_and_loc(ip, last_loc)
|
||||
|
||||
assert then_val.type == else_val.type, \
|
||||
f'ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}'
|
||||
ret_type = then_val.type
|
||||
|
||||
ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else []
|
||||
if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True)
|
||||
then_block.merge_block_before(if_op.get_then_block())
|
||||
if ret_type_ir:
|
||||
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
||||
self.builder.create_yield_op([then_val.handle])
|
||||
|
||||
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
||||
else_block.merge_block_before(if_op.get_else_block())
|
||||
if ret_type_ir:
|
||||
self.builder.set_insertion_point_to_end(if_op.get_else_block())
|
||||
self.builder.create_yield_op([else_val.handle])
|
||||
return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None
|
||||
else:
|
||||
cond = _unwrap_if_constexpr(cond)
|
||||
if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks
|
||||
|
||||
# not isinstance - we insist the real thing, no subclasses and no ducks
|
||||
if type(cond) not in _condition_types:
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
|
||||
', '.join(_.__name__ for _ in _condition_types), type(cond).__name__))
|
||||
None, node,
|
||||
"`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
|
||||
', '.join(_.__name__ for _ in _condition_types),
|
||||
type(cond).__name__))
|
||||
if cond:
|
||||
return self.visit(node.body)
|
||||
else:
|
||||
@@ -654,8 +712,10 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
return constexpr(lhs_value is not rhs_value)
|
||||
method_name = self._method_name_for_comp_op.get(type(node.ops[0]))
|
||||
if method_name is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
|
||||
return self._apply_binary_method(method_name, lhs, rhs)
|
||||
|
||||
_method_name_for_comp_op: Dict[Type[ast.cmpop], str] = {
|
||||
ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__'
|
||||
}
|
||||
@@ -664,11 +724,15 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
op = self.visit(node.operand)
|
||||
fn = self._method_name_for_unary_op.get(type(node.op))
|
||||
if fn is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST unary operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "AST unary operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
if _is_triton_tensor(op):
|
||||
return getattr(op, fn)(_builder=self.builder)
|
||||
return getattr(op, fn)()
|
||||
_method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'}
|
||||
|
||||
_method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {
|
||||
ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'
|
||||
}
|
||||
|
||||
def visit_While(self, node):
|
||||
with enter_sub_region(self) as sr:
|
||||
@@ -763,9 +827,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
iter_args = [self.visit(arg) for arg in node.iter.args]
|
||||
if IteratorClass == language.static_range:
|
||||
iterator = IteratorClass(*iter_args)
|
||||
static_range = range(iterator.start.value,
|
||||
iterator.end.value,
|
||||
iterator.step.value)
|
||||
static_range = range(iterator.start.value, iterator.end.value, iterator.step.value)
|
||||
for i in static_range:
|
||||
self.lscope[node.target.id] = constexpr(i)
|
||||
self.visit_compound_statement(node.body)
|
||||
@@ -902,8 +964,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def call_JitFunction(self, fn: JITFunction, args, kwargs):
|
||||
args = inspect.getcallargs(fn.fn, *args, **kwargs)
|
||||
args = [args[name] for name in fn.arg_names]
|
||||
args = [arg if _is_triton_tensor(arg)
|
||||
else constexpr(arg) for arg in args]
|
||||
args = [arg if _is_triton_tensor(arg) else constexpr(arg) for arg in args]
|
||||
# generate function def
|
||||
attributes = dict()
|
||||
constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
|
||||
@@ -921,8 +982,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
debug = self.debug if fn.debug is None else fn.debug
|
||||
file_name, begin_line = _get_fn_file_line(fn)
|
||||
generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module,
|
||||
function_name=fn_name, function_types=self.function_ret_types, debug=debug, noinline=fn.noinline,
|
||||
file_name=file_name, begin_line=begin_line, target=self.builder.target)
|
||||
function_name=fn_name, function_types=self.function_ret_types, debug=debug,
|
||||
noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
|
||||
target=self.builder.target)
|
||||
generator.visit(fn.parse())
|
||||
callee_ret_type = generator.last_ret_type
|
||||
self.function_ret_types[fn_name] = callee_ret_type
|
||||
@@ -950,7 +1012,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
kws = dict(self.visit(keyword) for keyword in node.keywords)
|
||||
args = [self.visit(arg) for arg in node.args]
|
||||
if fn is language.core.device_assert: # TODO: this should not be so hardcoded
|
||||
if fn is language.core.device_assert: # TODO: this should not be so hardcoded
|
||||
if not self.debug:
|
||||
return
|
||||
if isinstance(fn, JITFunction):
|
||||
@@ -971,16 +1033,21 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def visit_BoolOp(self, node: ast.BoolOp):
|
||||
if len(node.values) != 2:
|
||||
raise UnsupportedLanguageConstruct(None, node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node,
|
||||
"chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
|
||||
lhs = self.visit(node.values[0])
|
||||
rhs = self.visit(node.values[1])
|
||||
method_name = self._method_name_for_bool_op.get(type(node.op))
|
||||
if method_name is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
return self._apply_binary_method(method_name, lhs, rhs)
|
||||
|
||||
_method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
|
||||
def visit_NameConstant(self, node):
|
||||
return constexpr(node.value)
|
||||
|
||||
@@ -1013,7 +1080,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
evaluated = self.visit(value.value)
|
||||
if not _is_constexpr(evaluated):
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + str(type(evaluated)))
|
||||
None, node,
|
||||
"Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type "
|
||||
+ str(type(evaluated)))
|
||||
values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value)
|
||||
else:
|
||||
raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value)))
|
||||
@@ -1055,7 +1124,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
passed = _unwrap_if_constexpr(self.visit(node.args[0]))
|
||||
if not isinstance(passed, bool):
|
||||
raise NotImplementedError("Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values")
|
||||
raise NotImplementedError(
|
||||
"Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values"
|
||||
)
|
||||
if not passed:
|
||||
if arg_count == 1:
|
||||
message = ""
|
||||
@@ -1144,10 +1215,9 @@ def ast_to_ttir(fn, signature, specialization, constants, debug, target):
|
||||
file_name, begin_line = _get_fn_file_line(fn)
|
||||
|
||||
prototype = language.function_type([], arg_types)
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants,
|
||||
function_name=function_name, attributes=new_attrs,
|
||||
is_kernel=True, debug=debug, file_name=file_name, begin_line=begin_line,
|
||||
target=target)
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name,
|
||||
attributes=new_attrs, is_kernel=True, debug=debug, file_name=file_name,
|
||||
begin_line=begin_line, target=target)
|
||||
try:
|
||||
generator.visit(fn.parse())
|
||||
except CompilationError as e:
|
||||
|
||||
@@ -11,25 +11,21 @@ from typing import Any
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs,
|
||||
compile_ptx_to_cubin, get_env_vars, get_num_warps,
|
||||
get_shared_memory_size, ir, runtime,
|
||||
translate_llvmir_to_ptx,
|
||||
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs, compile_ptx_to_cubin, get_env_vars,
|
||||
get_num_warps, get_shared_memory_size, ir, runtime, translate_llvmir_to_ptx,
|
||||
translate_triton_gpu_to_llvmir)
|
||||
from ..common.backend import get_backend, path_to_ptxas
|
||||
from ..common.backend import get_backend, get_cuda_version_key, path_to_ptxas
|
||||
from ..common.build import is_hip
|
||||
# from ..runtime import driver, jit, JITFunction
|
||||
# TODO: runtime.errors
|
||||
from ..runtime.autotuner import OutOfResources
|
||||
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
|
||||
from ..runtime.driver import driver
|
||||
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device,
|
||||
get_device_capability, version_key)
|
||||
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device, get_device_capability)
|
||||
from ..tools.disasm import get_sass
|
||||
from .code_generator import ast_to_ttir
|
||||
from .make_launcher import make_stub
|
||||
from .utils import (InfoFromBackendForTensorMap, TensorMapManager,
|
||||
get_ids_of_tensormaps, parse_tma_info)
|
||||
from .utils import (InfoFromBackendForTensorMap, TensorMapManager, get_ids_of_tensormaps, parse_tma_info)
|
||||
|
||||
CUDA_DEFAULT_WARP_SIZE = 32
|
||||
|
||||
@@ -45,6 +41,7 @@ def _is_cuda(target):
|
||||
|
||||
|
||||
class LazyDict(dict):
|
||||
|
||||
def __getitem__(self, key):
|
||||
val = dict.__getitem__(self, key)
|
||||
if callable(val):
|
||||
@@ -103,8 +100,13 @@ def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, target):
|
||||
return mod
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target,
|
||||
cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_inst_type):
|
||||
=======
|
||||
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization,
|
||||
enable_persistent, optimize_epilogue):
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
is_cuda = _is_cuda(target)
|
||||
if is_cuda:
|
||||
capability = target.capability
|
||||
@@ -128,9 +130,13 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target,
|
||||
if optimize_epilogue:
|
||||
pm.add_tritongpu_optimize_epilogue_pass()
|
||||
pm.add_tritongpu_optimize_dot_operands_pass()
|
||||
<<<<<<< HEAD
|
||||
if num_stages == 0 and is_hip() and gpu_matrix_core_version() != 0:
|
||||
pm.add_tritongpu_stream_pipeline_pass()
|
||||
pm.add_canonicalizer_pass()
|
||||
=======
|
||||
pm.add_cse_pass()
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
ws_enabled = False
|
||||
# `num_warps` does not mean the total number of warps of a CTA when
|
||||
# warp specialization is enabled.
|
||||
@@ -174,6 +180,8 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target,
|
||||
if is_cuda and capability // 10 >= 9:
|
||||
pm.add_tritongpu_fence_insertion_pass()
|
||||
pm.add_tritongpu_ws_fixup_missing_attrs_pass()
|
||||
pm.add_tritongpu_optimize_thread_locality_pass()
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
@@ -197,6 +205,7 @@ def ttgir_to_llir(mod, extern_libs, target, tma_infos, waves_per_eu=0):
|
||||
|
||||
# PTX translation
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def ptx_get_version(cuda_version) -> int:
|
||||
'''
|
||||
@@ -261,7 +270,11 @@ def convert_type_repr(x):
|
||||
return x
|
||||
|
||||
|
||||
def make_hash(fn, target, env_vars, **kwargs):
|
||||
def make_hash(fn, target, env_vars, device_backend, **kwargs):
|
||||
if device_backend is None:
|
||||
version_key = get_cuda_version_key()
|
||||
else:
|
||||
version_key = device_backend.get_version_key()
|
||||
if isinstance(fn, JITFunction):
|
||||
configs = kwargs["configs"]
|
||||
signature = kwargs["signature"]
|
||||
@@ -275,16 +288,21 @@ def make_hash(fn, target, env_vars, **kwargs):
|
||||
enable_persistent = kwargs.get("enable_persistent", False)
|
||||
debug = kwargs.get("debug", False)
|
||||
# Get unique key for the compiled code
|
||||
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
|
||||
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1),
|
||||
sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
|
||||
configs_key = [get_conf_key(conf) for conf in configs]
|
||||
env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())]
|
||||
<<<<<<< HEAD
|
||||
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
|
||||
=======
|
||||
key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
return hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||
assert isinstance(fn, str)
|
||||
ignore_version = kwargs.get('ignore_version', False)
|
||||
if (ignore_version):
|
||||
return hashlib.md5((Path(fn).read_text()).encode("utf-8")).hexdigest()
|
||||
return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest()
|
||||
return hashlib.md5((Path(fn).read_text() + version_key).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
|
||||
@@ -321,12 +339,14 @@ else:
|
||||
|
||||
|
||||
def _get_jsonable_constants(constants):
|
||||
|
||||
def _is_jsonable(x):
|
||||
try:
|
||||
json.dumps(x)
|
||||
return True
|
||||
except (TypeError, OverflowError):
|
||||
return False
|
||||
|
||||
serialized_constants = {}
|
||||
for constant in constants:
|
||||
if _is_jsonable(constants[constant]):
|
||||
@@ -341,7 +361,9 @@ def parse_mlir_module(path, context):
|
||||
return module
|
||||
|
||||
|
||||
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], defaults=[set(), set(), set(), set()])
|
||||
instance_descriptor = namedtuple("instance_descriptor",
|
||||
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
|
||||
defaults=[set(), set(), set(), set()])
|
||||
|
||||
|
||||
def is_hip():
|
||||
@@ -385,10 +407,16 @@ def get_arch_default_num_stages(device_type, capability=None):
|
||||
|
||||
|
||||
def add_cuda_stages(target, extern_libs, stages):
|
||||
<<<<<<< HEAD
|
||||
stages["ptx"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: llir_to_ptx(src, target))
|
||||
stages["cubin"] = (lambda path: Path(path).read_bytes(),
|
||||
lambda src: ptx_to_cubin(src, target))
|
||||
=======
|
||||
|
||||
stages["ptx"] = (lambda path: Path(path).read_text(), lambda src: llir_to_ptx(src, target))
|
||||
stages["cubin"] = (lambda path: Path(path).read_bytes(), lambda src: ptx_to_cubin(src, target))
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
|
||||
|
||||
def compile(fn, **kwargs):
|
||||
@@ -434,7 +462,8 @@ def compile(fn, **kwargs):
|
||||
# build architecture descriptor
|
||||
if device_type == "cuda":
|
||||
_device_backend = get_backend(device_type)
|
||||
target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps, enable_fp_fusion=enable_fp_fusion)
|
||||
target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps,
|
||||
enable_fp_fusion=enable_fp_fusion)
|
||||
else:
|
||||
_device_backend = get_backend(device_type)
|
||||
assert _device_backend
|
||||
@@ -443,11 +472,12 @@ def compile(fn, **kwargs):
|
||||
# build compilation stages
|
||||
stages = dict()
|
||||
stages["ast"] = (lambda path: fn, None)
|
||||
stages["ttir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
|
||||
stages["ttir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttir(
|
||||
ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
|
||||
if is_cuda:
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttgir(
|
||||
ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info,
|
||||
enable_warp_specialization, enable_persistent, optimize_epilogue))
|
||||
stages["llir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, target, tma_infos))
|
||||
add_cuda_stages(target, extern_libs, stages)
|
||||
@@ -507,18 +537,21 @@ def compile(fn, **kwargs):
|
||||
if ir_name == 'ttgir':
|
||||
num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
|
||||
assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
|
||||
assert "num_warps" not in kwargs or int(num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile"
|
||||
assert "num_warps" not in kwargs or int(
|
||||
num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile"
|
||||
num_warps = int(num_warps_matches[0])
|
||||
param_tys = [convert_type_repr(ty) for ty in types]
|
||||
signature = {k: v for k, v in enumerate(param_tys)}
|
||||
first_stage = list(stages.keys()).index(ir_name)
|
||||
|
||||
# create cache manager
|
||||
fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), **kwargs))
|
||||
fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs))
|
||||
# managers used to dump and override IR for debugging
|
||||
enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
|
||||
fn_override_manager = get_override_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True))
|
||||
fn_dump_manager = get_dump_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True))
|
||||
fn_override_manager = get_override_manager(
|
||||
make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
|
||||
fn_dump_manager = get_dump_manager(
|
||||
make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
|
||||
|
||||
# determine name and extension type of provided function
|
||||
if isinstance(fn, JITFunction):
|
||||
@@ -531,9 +564,7 @@ def compile(fn, **kwargs):
|
||||
metadata_filename = f"{name}.json"
|
||||
|
||||
# The group is addressed by the metadata
|
||||
metadata_group = fn_cache_manager.get_group(
|
||||
metadata_filename
|
||||
) or {}
|
||||
metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
|
||||
|
||||
metadata_path = metadata_group.get(metadata_filename)
|
||||
|
||||
@@ -541,9 +572,9 @@ def compile(fn, **kwargs):
|
||||
with open(metadata_path) as f:
|
||||
metadata = json.load(f)
|
||||
if 'tensormaps_info' in metadata:
|
||||
metadata['tensormaps_info'] = [
|
||||
InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']]
|
||||
metadata['tensormaps_info'] = [InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']]
|
||||
else:
|
||||
<<<<<<< HEAD
|
||||
metadata = {"num_warps": num_warps,
|
||||
"warp_size": warp_size,
|
||||
"num_ctas": num_ctas,
|
||||
@@ -555,6 +586,18 @@ def compile(fn, **kwargs):
|
||||
"constants": _get_jsonable_constants(constants),
|
||||
"debug": debug,
|
||||
"target": target, }
|
||||
=======
|
||||
metadata = {
|
||||
"num_warps": num_warps,
|
||||
"num_ctas": num_ctas,
|
||||
"num_stages": num_stages,
|
||||
"enable_warp_specialization": enable_warp_specialization,
|
||||
"enable_persistent": enable_persistent,
|
||||
"constants": _get_jsonable_constants(constants),
|
||||
"debug": debug,
|
||||
"target": target,
|
||||
}
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
metadata.update(get_env_vars())
|
||||
if ext == "ptx":
|
||||
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
|
||||
@@ -626,10 +669,7 @@ def compile(fn, **kwargs):
|
||||
|
||||
ids_of_folded_args = tuple([int(k) for k in configs[0].ids_of_folded_args]) if isinstance(fn, JITFunction) else ()
|
||||
if "clusterDims" not in metadata:
|
||||
metadata["clusterDims"] = [
|
||||
cluster_info.clusterDimX,
|
||||
cluster_info.clusterDimY,
|
||||
cluster_info.clusterDimZ]
|
||||
metadata["clusterDims"] = [cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ]
|
||||
|
||||
if len(tma_infos) > 0:
|
||||
metadata["tensormaps_info"] = parse_tma_info(tma_infos, ids_of_folded_args)
|
||||
@@ -643,7 +683,10 @@ def compile(fn, **kwargs):
|
||||
fn.tensormaps_info = metadata["tensormaps_info"]
|
||||
|
||||
ids_of_const_exprs = tuple(fn.constexprs) if isinstance(fn, JITFunction) else ()
|
||||
ids = {"ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs": ids_of_const_exprs}
|
||||
ids = {
|
||||
"ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs":
|
||||
ids_of_const_exprs
|
||||
}
|
||||
# cache manager
|
||||
if is_cuda:
|
||||
so_path = make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization)
|
||||
@@ -651,7 +694,8 @@ def compile(fn, **kwargs):
|
||||
so_path = _device_backend.make_launcher_stub(name, signature, constants, ids)
|
||||
# write-back metadata, if it didn't come from the cache
|
||||
if metadata_path is None:
|
||||
metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, binary=False)
|
||||
metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
|
||||
binary=False)
|
||||
fn_cache_manager.put_group(metadata_filename, metadata_group)
|
||||
|
||||
# return handle to compiled kernel
|
||||
@@ -701,10 +745,7 @@ class CompiledKernel:
|
||||
|
||||
if self.device_type in ["cuda"]:
|
||||
device = get_current_device()
|
||||
bin_path = {
|
||||
driver.HIP: "hsaco_path",
|
||||
driver.CUDA: "cubin"
|
||||
}[driver.backend]
|
||||
bin_path = {driver.HIP: "hsaco_path", driver.CUDA: "cubin"}[driver.backend]
|
||||
max_shared = driver.utils.get_device_properties(device)["max_shared_mem"]
|
||||
fn_load_binary = driver.utils.load_binary
|
||||
else:
|
||||
@@ -752,4 +793,5 @@ class CompiledKernel:
|
||||
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.clusterDims[0],
|
||||
self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function,
|
||||
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand)
|
||||
|
||||
return runner
|
||||
|
||||
@@ -3,9 +3,9 @@ import os
|
||||
import tempfile
|
||||
|
||||
from ..common import _build
|
||||
from ..common.backend import get_cuda_version_key
|
||||
from ..common.build import is_hip
|
||||
from ..runtime.cache import get_cache_manager
|
||||
from ..runtime.jit import version_key
|
||||
from .utils import generate_cu_signature
|
||||
|
||||
# ----- stub --------
|
||||
@@ -23,7 +23,7 @@ def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
|
||||
|
||||
def make_stub(name, signature, constants, ids, **kwargs):
|
||||
# name of files that are cached
|
||||
so_cache_key = make_so_cache_key(version_key(), signature, constants, ids, **kwargs)
|
||||
so_cache_key = make_so_cache_key(get_cuda_version_key(), signature, constants, ids, **kwargs)
|
||||
so_cache_manager = get_cache_manager(so_cache_key)
|
||||
so_name = f"{name}.so"
|
||||
# retrieve stub from cache if it exists
|
||||
@@ -40,6 +40,7 @@ def make_stub(name, signature, constants, ids, **kwargs):
|
||||
else:
|
||||
return cache_path
|
||||
|
||||
|
||||
# ----- source code generation --------
|
||||
|
||||
|
||||
@@ -100,7 +101,10 @@ def generate_launcher(constants, signature, ids):
|
||||
|
||||
# generate glue code
|
||||
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
|
||||
params = [i for i in signature.keys() if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)]
|
||||
params = [
|
||||
i for i in signature.keys()
|
||||
if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)
|
||||
]
|
||||
src = f"""
|
||||
#include \"cuda.h\"
|
||||
#include <stdbool.h>
|
||||
|
||||
@@ -158,19 +158,21 @@ class InfoFromBackendForTensorMap:
|
||||
|
||||
# dtype:cuda.CUtensorMapDataType | int
|
||||
def bytes_from_type(self, dtype):
|
||||
return {driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT8"]: 1,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ"]: 4}[dtype]
|
||||
return {
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT8"]: 1,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT64"]: 8,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"]: 2,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32"]: 4,
|
||||
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ"]: 4
|
||||
}[dtype]
|
||||
|
||||
def getTensorMapDataType(self):
|
||||
return self.tensorDataType
|
||||
@@ -259,22 +261,29 @@ class InfoFromBackendForTensorMap:
|
||||
self.getInterleave(),
|
||||
self.getSwizzle(),
|
||||
self.getL2Promotion(),
|
||||
self.getOobFill()
|
||||
self.getOobFill(),
|
||||
)
|
||||
|
||||
# make hashable to use as partial key in cache
|
||||
def __hash__(self):
|
||||
return hash((self.ids_of_folded_args, self.globalAddressArgIdx, tuple(self.globalDimsArgIdx), tuple(self.globalStridesArgIdx), self.tensorDataType,
|
||||
self.tensorRank, tuple(self.boxDims), tuple(self.elementStrides), self.interleave, self.swizzle, self.l2Promotion, self.oobFill))
|
||||
return hash((self.ids_of_folded_args, self.globalAddressArgIdx, tuple(self.globalDimsArgIdx),
|
||||
tuple(self.globalStridesArgIdx), self.tensorDataType, self.tensorRank, tuple(self.boxDims),
|
||||
tuple(self.elementStrides), self.interleave, self.swizzle, self.l2Promotion, self.oobFill))
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
return (self.ids_of_folded_args, self.globalAddressArgIdx, self.globalDimsArgIdx, self.globalStridesArgIdx, self.tensorDataType, self.tensorRank, self.boxDims, self.elementStrides, self.interleave, self.swizzle, self.l2Promotion, self.oobFill) == (
|
||||
other.ids_of_folded_args, other.globalAddressArgIdx, other.globalDimsArgIdx, other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims, other.elementStrides, other.interleave, other.swizzle, other.l2Promotion, other.oobFill)
|
||||
return (self.ids_of_folded_args, self.globalAddressArgIdx, self.globalDimsArgIdx, self.globalStridesArgIdx,
|
||||
self.tensorDataType, self.tensorRank, self.boxDims, self.elementStrides, self.interleave, self.swizzle,
|
||||
self.l2Promotion,
|
||||
self.oobFill) == (other.ids_of_folded_args, other.globalAddressArgIdx, other.globalDimsArgIdx,
|
||||
other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims,
|
||||
other.elementStrides, other.interleave, other.swizzle, other.l2Promotion,
|
||||
other.oobFill)
|
||||
|
||||
|
||||
class TensorMapManager:
|
||||
|
||||
def __init__(self):
|
||||
self.tensormaps_device = {}
|
||||
|
||||
@@ -286,8 +295,7 @@ class TensorMapManager:
|
||||
t_tensormap = e.tensormap(args)
|
||||
TENSORMAP_SIZE_IN_BYTES = 128
|
||||
t_tensormap_device = driver.utils.cuMemAlloc(TENSORMAP_SIZE_IN_BYTES)
|
||||
driver.utils.cuMemcpyHtoD(
|
||||
t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES)
|
||||
driver.utils.cuMemcpyHtoD(t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES)
|
||||
self.tensormaps_device[key] = t_tensormap_device
|
||||
return int(self.tensormaps_device[key])
|
||||
|
||||
|
||||
@@ -111,7 +111,6 @@ from .random import (
|
||||
uint32_to_uniform_float,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TRITON_MAX_TENSOR_NUMEL",
|
||||
"abs",
|
||||
|
||||
@@ -22,10 +22,8 @@ def builtin(fn: T) -> T:
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if "_builder" not in kwargs or kwargs["_builder"] is None:
|
||||
raise ValueError(
|
||||
"Did you forget to add @triton.jit ? "
|
||||
"(`_builder` argument must be provided outside of JIT functions.)"
|
||||
)
|
||||
raise ValueError("Did you forget to add @triton.jit ? "
|
||||
"(`_builder` argument must be provided outside of JIT functions.)")
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
setattr(wrapper, TRITON_BUILTIN, True)
|
||||
@@ -54,7 +52,7 @@ def _to_tensor(x, builder):
|
||||
else:
|
||||
raise RuntimeError(f'Nonrepresentable integer {x}.')
|
||||
elif isinstance(x, float):
|
||||
min_float32 = 2 ** -126
|
||||
min_float32 = 2**-126
|
||||
max_float32 = (2 - 2**-23) * 2**127
|
||||
abs_x = __builtins__['abs'](x)
|
||||
if abs_x == float("inf") or\
|
||||
@@ -243,7 +241,7 @@ class dtype:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.name,))
|
||||
return hash((self.name, ))
|
||||
|
||||
@property
|
||||
def scalar(self):
|
||||
@@ -297,6 +295,7 @@ class dtype:
|
||||
|
||||
|
||||
class pointer_type(dtype):
|
||||
|
||||
def __init__(self, element_ty: dtype, address_space: int = 1):
|
||||
if not isinstance(element_ty, dtype):
|
||||
raise TypeError('element_ty is a {type(element_ty).__name__}.')
|
||||
@@ -331,6 +330,7 @@ class pointer_type(dtype):
|
||||
|
||||
|
||||
class block_type(dtype):
|
||||
|
||||
def __init__(self, element_ty: dtype, shape: List):
|
||||
self.element_ty = element_ty
|
||||
|
||||
@@ -381,6 +381,7 @@ class block_type(dtype):
|
||||
|
||||
|
||||
class function_type(dtype):
|
||||
|
||||
def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None:
|
||||
self.ret_types = ret_types
|
||||
self.param_types = param_types
|
||||
@@ -531,7 +532,7 @@ class constexpr:
|
||||
return constexpr(~self.value)
|
||||
|
||||
def __pow__(self, other):
|
||||
return constexpr(self.value ** other.value)
|
||||
return constexpr(self.value**other.value)
|
||||
|
||||
def __rshift__(self, other):
|
||||
return constexpr(self.value >> other.value)
|
||||
@@ -547,6 +548,7 @@ class constexpr:
|
||||
|
||||
|
||||
class tensor:
|
||||
|
||||
def __init__(self, handle, type: dtype):
|
||||
# IR handle
|
||||
self.handle = handle
|
||||
@@ -740,11 +742,21 @@ class tensor:
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.equal(self, other, _builder)
|
||||
|
||||
@builtin
|
||||
def __req__(self, other, _builder=None):
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.equal(other, self, _builder)
|
||||
|
||||
@builtin
|
||||
def __ne__(self, other, _builder=None):
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.not_equal(self, other, _builder)
|
||||
|
||||
@builtin
|
||||
def __rne__(self, other, _builder=None):
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.not_equal(other, self, _builder)
|
||||
|
||||
@builtin
|
||||
def logical_and(self, other, _builder=None):
|
||||
other = _to_tensor(other, _builder)
|
||||
@@ -1023,6 +1035,7 @@ def expand_dims(input, axis, _builder=None):
|
||||
ret = semantic.expand_dims(ret, a, _builder)
|
||||
return ret
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Linear Algebra
|
||||
# -----------------------
|
||||
@@ -1171,6 +1184,7 @@ def advance(base: tensor, offsets, _builder=None):
|
||||
"""
|
||||
return semantic.advance(base, offsets, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Atomic Memory Operations
|
||||
# -----------------------
|
||||
@@ -1196,6 +1210,9 @@ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
|
||||
:param sem: Memory semantics to use ("ACQUIRE_RELEASE" (default),
|
||||
"ACQUIRE", "RELEASE", or "RELAXED")
|
||||
:type sem: str
|
||||
:param scope: Scope of threads that observe synchronizing effect of the
|
||||
atomic operation ("GPU" (default), "CTA", or "SYSTEM")
|
||||
:type scope: str
|
||||
"""
|
||||
func.__doc__ = docstr
|
||||
return func
|
||||
@@ -1205,73 +1222,82 @@ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("compare-and-swap", has_cmp=True)
|
||||
def atomic_cas(pointer, cmp, val, sem=None, _builder=None):
|
||||
def atomic_cas(pointer, cmp, val, sem=None, scope=None, _builder=None):
|
||||
cmp = _to_tensor(cmp, _builder)
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_cas(pointer, cmp, val, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_cas(pointer, cmp, val, sem, scope, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("exchange")
|
||||
def atomic_xchg(pointer, val, mask=None, sem=None, _builder=None):
|
||||
def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_xchg(pointer, val, mask, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_xchg(pointer, val, mask, sem, scope, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("add")
|
||||
def atomic_add(pointer, val, mask=None, sem=None, _builder=None):
|
||||
def atomic_add(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_add(pointer, val, mask, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_add(pointer, val, mask, sem, scope, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("max")
|
||||
def atomic_max(pointer, val, mask=None, sem=None, _builder=None):
|
||||
def atomic_max(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_max(pointer, val, mask, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_max(pointer, val, mask, sem, scope, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("min")
|
||||
def atomic_min(pointer, val, mask=None, sem=None, _builder=None):
|
||||
def atomic_min(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_min(pointer, val, mask, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_min(pointer, val, mask, sem, scope, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("logical and")
|
||||
def atomic_and(pointer, val, mask=None, sem=None, _builder=None):
|
||||
def atomic_and(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_and(pointer, val, mask, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_and(pointer, val, mask, sem, scope, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("logical or")
|
||||
def atomic_or(pointer, val, mask=None, sem=None, _builder=None):
|
||||
def atomic_or(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_or(pointer, val, mask, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_or(pointer, val, mask, sem, scope, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("logical xor")
|
||||
def atomic_xor(pointer, val, mask=None, sem=None, _builder=None):
|
||||
def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_xor(pointer, val, mask, sem, _builder)
|
||||
scope = _constexpr_to_value(scope)
|
||||
return semantic.atomic_xor(pointer, val, mask, sem, scope, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Conditioning
|
||||
# -----------------------
|
||||
|
||||
|
||||
@builtin
|
||||
def where(condition, x, y, _builder=None):
|
||||
"""
|
||||
@@ -1299,6 +1325,7 @@ def where(condition, x, y, _builder=None):
|
||||
# Math
|
||||
# -----------------------
|
||||
|
||||
|
||||
@builtin
|
||||
def umulhi(x, y, _builder=None):
|
||||
"""
|
||||
@@ -1392,6 +1419,7 @@ def abs(x, _builder=None):
|
||||
# Reductions
|
||||
# -----------------------
|
||||
|
||||
|
||||
def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]:
|
||||
|
||||
def _decorator(func: T) -> T:
|
||||
@@ -1430,8 +1458,7 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
|
||||
"""
|
||||
if isinstance(input, tensor):
|
||||
return reduce((input,), axis, combine_fn,
|
||||
_builder=_builder, _generator=_generator)[0]
|
||||
return reduce((input, ), axis, combine_fn, _builder=_builder, _generator=_generator)[0]
|
||||
|
||||
def make_combine_region(reduce_op):
|
||||
in_scalar_tys = [t.type.scalar for t in input]
|
||||
@@ -1441,14 +1468,14 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
with _insertion_guard(_builder):
|
||||
param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
|
||||
block = _builder.create_block_with_parent(region, param_types)
|
||||
args = [tensor(block.arg(i), ty)
|
||||
for i, ty in enumerate(prototype.param_types)]
|
||||
args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
|
||||
results = _generator.call_JitFunction(combine_fn, args, kwargs={})
|
||||
if isinstance(results, tensor):
|
||||
handles = [results.handle]
|
||||
else:
|
||||
handles = [r.handle for r in results]
|
||||
_builder.create_reduce_ret(*handles)
|
||||
|
||||
if axis is not None:
|
||||
axis = _constexpr_to_value(axis)
|
||||
return semantic.reduction(input, axis, make_combine_region, _builder)
|
||||
@@ -1483,8 +1510,7 @@ def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None
|
||||
index = expand_dims(index, axes_to_expand, _builder=_builder)
|
||||
index = broadcast_to(index, input.shape, _builder=_builder)
|
||||
|
||||
rvalue, rindices = reduce((input, index), axis, combine_fn,
|
||||
_builder=_builder, _generator=_generator)
|
||||
rvalue, rindices = reduce((input, index), axis, combine_fn, _builder=_builder, _generator=_generator)
|
||||
return rvalue, rindices
|
||||
|
||||
|
||||
@@ -1492,6 +1518,7 @@ def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None
|
||||
# Scans
|
||||
# -----------------------
|
||||
|
||||
|
||||
def _add_scan_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]:
|
||||
|
||||
def _decorator(func: T) -> T:
|
||||
@@ -1516,8 +1543,7 @@ def associative_scan(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
|
||||
"""
|
||||
if isinstance(input, tensor):
|
||||
return associative_scan((input,), axis, combine_fn,
|
||||
_builder=_builder, _generator=_generator)[0]
|
||||
return associative_scan((input, ), axis, combine_fn, _builder=_builder, _generator=_generator)[0]
|
||||
|
||||
def make_combine_region(scan_op):
|
||||
in_scalar_tys = [t.type.scalar for t in input]
|
||||
@@ -1527,17 +1553,18 @@ def associative_scan(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
with _insertion_guard(_builder):
|
||||
param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
|
||||
block = _builder.create_block_with_parent(region, param_types)
|
||||
args = [tensor(block.arg(i), ty)
|
||||
for i, ty in enumerate(prototype.param_types)]
|
||||
args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
|
||||
results = _generator.call_JitFunction(combine_fn, args, kwargs={})
|
||||
if isinstance(results, tensor):
|
||||
handles = [results.handle]
|
||||
else:
|
||||
handles = [r.handle for r in results]
|
||||
_builder.create_scan_ret(*handles)
|
||||
|
||||
axis = _constexpr_to_value(axis)
|
||||
return semantic.associative_scan(input, axis, make_combine_region, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Compiler Hint Ops
|
||||
# -----------------------
|
||||
@@ -1600,6 +1627,8 @@ def max_constancy(input, values, _builder=None):
|
||||
raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
|
||||
values = [x.value for x in values]
|
||||
return semantic.max_constancy(input, values)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Debugging functions
|
||||
# -----------------------
|
||||
@@ -1739,12 +1768,12 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur
|
||||
broadcast_arg = dispatch_args[0]
|
||||
# Get the broadcast shape over all the arguments
|
||||
for i, item in enumerate(dispatch_args):
|
||||
_, broadcast_arg = semantic.binary_op_type_checking_impl(
|
||||
item, broadcast_arg, _builder, arithmetic_check=False)
|
||||
_, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder,
|
||||
arithmetic_check=False)
|
||||
# Change the shape of each argument based on the broadcast shape
|
||||
for i in range(len(dispatch_args)):
|
||||
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
|
||||
dispatch_args[i], broadcast_arg, _builder, arithmetic_check=False)
|
||||
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder,
|
||||
arithmetic_check=False)
|
||||
ret_shape = broadcast_arg.shape
|
||||
res_ty = block_type(dtype, ret_shape)
|
||||
call = _builder.create_inline_asm(asm, constraints, [t.handle for t in args], res_ty.to_ir(_builder), is_pure, pack)
|
||||
@@ -1757,7 +1786,6 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur
|
||||
|
||||
|
||||
class static_range:
|
||||
|
||||
"""
|
||||
Iterator that counts upward forever.
|
||||
|
||||
@@ -1801,7 +1829,9 @@ class static_range:
|
||||
# Extern functions
|
||||
# -----------------------
|
||||
|
||||
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, is_pure: bool, _builder=None):
|
||||
|
||||
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple,
|
||||
is_pure: bool, _builder=None):
|
||||
'''
|
||||
Dispatch a function to a library
|
||||
:param func: the function to dispatch
|
||||
@@ -1843,7 +1873,8 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic
|
||||
return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type)
|
||||
|
||||
|
||||
def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, _builder=None):
|
||||
def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool,
|
||||
_builder=None):
|
||||
'''
|
||||
Dispatch an elementwise function to a library
|
||||
:param lib_name: the name of the library
|
||||
@@ -1872,12 +1903,12 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
|
||||
broadcast_arg = dispatch_args[0]
|
||||
# Get the broadcast shape over all the arguments
|
||||
for i, item in enumerate(dispatch_args):
|
||||
_, broadcast_arg = semantic.binary_op_type_checking_impl(
|
||||
item, broadcast_arg, _builder, arithmetic_check=arithmetic_check)
|
||||
_, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder,
|
||||
arithmetic_check=arithmetic_check)
|
||||
# Change the shape of each argument based on the broadcast shape
|
||||
for i in range(len(dispatch_args)):
|
||||
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
|
||||
dispatch_args[i], broadcast_arg, _builder, arithmetic_check=arithmetic_check)
|
||||
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder,
|
||||
arithmetic_check=arithmetic_check)
|
||||
if not all_scalar:
|
||||
ret_shape = broadcast_arg.shape
|
||||
func = getattr(_builder, "create_extern_elementwise")
|
||||
|
||||
@@ -3,16 +3,14 @@ from .. import core
|
||||
|
||||
@core.extern
|
||||
def globaltimer(_builder=None):
|
||||
return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [],
|
||||
dtype=core.int64, is_pure=False,
|
||||
pack=1, _builder=_builder)
|
||||
return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], dtype=core.int64, is_pure=False, pack=1,
|
||||
_builder=_builder)
|
||||
|
||||
|
||||
@core.extern
|
||||
def smid(_builder=None):
|
||||
return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [],
|
||||
dtype=core.int32, is_pure=True,
|
||||
pack=1, _builder=_builder)
|
||||
return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1,
|
||||
_builder=_builder)
|
||||
|
||||
|
||||
@core.builtin
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -91,6 +91,7 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
# two_to_the_minus_32: tl.constexpr = 2.328306e-10
|
||||
# return x * two_to_the_minus_32
|
||||
|
||||
|
||||
@jit
|
||||
def uint32_to_uniform_float(x):
|
||||
"""
|
||||
@@ -134,6 +135,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
u4 = uint32_to_uniform_float(i4)
|
||||
return u1, u2, u3, u4
|
||||
|
||||
|
||||
# -------------------
|
||||
# randn
|
||||
# -------------------
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -123,6 +123,7 @@ def maximum(x, y):
|
||||
"""
|
||||
return math.max(x, y)
|
||||
|
||||
|
||||
# max and argmax
|
||||
|
||||
|
||||
@@ -149,8 +150,7 @@ def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_reduction_docstr("maximum",
|
||||
return_indices_arg="return_indices",
|
||||
@core._add_reduction_docstr("maximum", return_indices_arg="return_indices",
|
||||
tie_break_arg="return_indices_tie_break_left")
|
||||
def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
|
||||
input = core._promote_reduction_input(input)
|
||||
@@ -175,6 +175,7 @@ def argmax(input, axis, tie_break_left=True):
|
||||
(_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
|
||||
return ret
|
||||
|
||||
|
||||
# min and argmin
|
||||
|
||||
|
||||
@@ -201,8 +202,7 @@ def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_reduction_docstr("minimum",
|
||||
return_indices_arg="return_indices",
|
||||
@core._add_reduction_docstr("minimum", return_indices_arg="return_indices",
|
||||
tie_break_arg="return_indices_tie_break_left")
|
||||
def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
|
||||
input = core._promote_reduction_input(input)
|
||||
@@ -222,8 +222,7 @@ def min(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_reduction_docstr("minimum index",
|
||||
tie_break_arg="tie_break_left")
|
||||
@core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left")
|
||||
def argmin(input, axis, tie_break_left=True):
|
||||
_, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
|
||||
return ret
|
||||
@@ -233,6 +232,7 @@ def argmin(input, axis, tie_break_left=True):
|
||||
def _sum_combine(a, b):
|
||||
return a + b
|
||||
|
||||
|
||||
# sum
|
||||
|
||||
|
||||
@@ -247,6 +247,7 @@ def sum(input, axis=None):
|
||||
def _xor_combine(a, b):
|
||||
return a ^ b
|
||||
|
||||
|
||||
# xor sum
|
||||
|
||||
|
||||
@@ -258,8 +259,8 @@ def xor_sum(input, axis=None, _builder=None, _generator=None):
|
||||
raise ValueError("xor_sum only supported for integers")
|
||||
|
||||
input = core._promote_reduction_input(input, _builder=_builder)
|
||||
return core.reduce(input, axis, _xor_combine,
|
||||
_builder=_builder, _generator=_generator)
|
||||
return core.reduce(input, axis, _xor_combine, _builder=_builder, _generator=_generator)
|
||||
|
||||
|
||||
# cumsum
|
||||
|
||||
@@ -271,6 +272,7 @@ def cumsum(input, axis=0):
|
||||
input = core._promote_reduction_input(input)
|
||||
return core.associative_scan(input, axis, _sum_combine)
|
||||
|
||||
|
||||
# cumprod
|
||||
|
||||
|
||||
|
||||
@@ -17,15 +17,14 @@ from ... import language as tl
|
||||
'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
|
||||
})
|
||||
@jit
|
||||
def _sdd_kernel(
|
||||
A, B, C,
|
||||
stride_za, stride_ha, stride_ma, stride_ak,
|
||||
stride_zb, stride_hb, stride_bk, stride_nb,
|
||||
stride_zc, stride_hc, stride_mc, stride_nc,
|
||||
K, grid_offset, lut,
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
|
||||
BLOCK: tl.constexpr, EVEN_K: tl.constexpr
|
||||
):
|
||||
def _sdd_kernel(A, B, C, #
|
||||
stride_za, stride_ha, stride_ma, stride_ak, #
|
||||
stride_zb, stride_hb, stride_bk, stride_nb, #
|
||||
stride_zc, stride_hc, stride_mc, stride_nc, #
|
||||
K, grid_offset, lut, #
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, #
|
||||
BLOCK: tl.constexpr, EVEN_K: tl.constexpr #
|
||||
):
|
||||
# ------------ #
|
||||
# - Prologue - #
|
||||
# ------------ #
|
||||
@@ -104,13 +103,13 @@ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=
|
||||
c = out
|
||||
grid = [c.shape[1], 1, c.shape[0]]
|
||||
_sdd_kernel[grid](
|
||||
a, b, c,
|
||||
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
|
||||
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
|
||||
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
|
||||
Ka, 0, lut,
|
||||
TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4,
|
||||
num_warps=4,
|
||||
a, b, c, #
|
||||
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), #
|
||||
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), #
|
||||
c.stride(0), c.stride(1), c.stride(2), c.stride(3), #
|
||||
Ka, 0, lut, #
|
||||
TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4, #
|
||||
num_warps=4 #
|
||||
)
|
||||
return c
|
||||
|
||||
@@ -120,6 +119,7 @@ def sdd_lut(layout, block, device):
|
||||
lut = lut.contiguous()
|
||||
return lut, None
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Dense = Sparse x Dense (DSD)
|
||||
# This operation uses a look-up table that contains pre-computed pointer increments
|
||||
@@ -128,15 +128,14 @@ def sdd_lut(layout, block, device):
|
||||
|
||||
|
||||
@jit
|
||||
def _dsd_kernel(
|
||||
A, B, C,
|
||||
stride_az, stride_ha, stride_am, stride_ak,
|
||||
stride_zb, stride_hb, stride_bk, stride_bn,
|
||||
stride_zc, stride_hc, stride_cm, stride_cn,
|
||||
DS0, DS1, lut,
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
|
||||
):
|
||||
def _dsd_kernel(A, B, C, #
|
||||
stride_az, stride_ha, stride_am, stride_ak, #
|
||||
stride_zb, stride_hb, stride_bk, stride_bn, #
|
||||
stride_zc, stride_hc, stride_cm, stride_cn, #
|
||||
DS0, DS1, lut, #
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, #
|
||||
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr #
|
||||
):
|
||||
# ------------ #
|
||||
# - Prologue - #
|
||||
# ------------ #
|
||||
@@ -229,13 +228,13 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N
|
||||
# compute output
|
||||
grid = lambda meta: [cdiv(BS3, meta['TILE_N']), width, BS0]
|
||||
_dsd_kernel[grid](
|
||||
a, b, c,
|
||||
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
|
||||
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
|
||||
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
|
||||
BS3, AS1, lut,
|
||||
TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
|
||||
num_warps=4, GROUP_SIZE_M=4,
|
||||
a, b, c, #
|
||||
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), #
|
||||
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), #
|
||||
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), #
|
||||
BS3, AS1, lut, #
|
||||
TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4, #
|
||||
num_warps=4, GROUP_SIZE_M=4 #
|
||||
)
|
||||
# exit()
|
||||
return c
|
||||
@@ -337,6 +336,7 @@ def dsd_lut(layout, block, step, trans, device):
|
||||
# create locks
|
||||
return lut, width
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Dense = Dense x Sparse (DDS)
|
||||
# -----------------------------
|
||||
@@ -346,6 +346,7 @@ def dsd_lut(layout, block, step, trans, device):
|
||||
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
|
||||
return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out)
|
||||
|
||||
|
||||
##############
|
||||
# MAIN API #
|
||||
##############
|
||||
@@ -356,10 +357,8 @@ class _matmul(torch.autograd.Function):
|
||||
fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul}
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
|
||||
c_lut, c_width, da_lut, da_width, db_lut, db_width, out
|
||||
):
|
||||
def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_width, da_lut, da_width, db_lut,
|
||||
db_width, out):
|
||||
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out)
|
||||
# save for backward
|
||||
ctx.save_for_backward(a, b)
|
||||
@@ -385,15 +384,13 @@ class _matmul(torch.autograd.Function):
|
||||
# gradients w.r.t. a
|
||||
if ctx.needs_input_grad[0]:
|
||||
mode_da = mode[1] + mode[0] + mode[2]
|
||||
da = _matmul.fn[mode_da](
|
||||
dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width,
|
||||
)
|
||||
da = _matmul.fn[mode_da](dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block,
|
||||
ctx.da_lut, ctx.da_width)
|
||||
# gradients w.r.t. b
|
||||
if ctx.needs_input_grad[1]:
|
||||
mode_db = mode[2] + mode[1] + mode[0]
|
||||
db = _matmul.fn[mode_db](
|
||||
a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width,
|
||||
)
|
||||
db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block,
|
||||
ctx.db_lut, ctx.db_width)
|
||||
dout = dc if ctx.has_out else None
|
||||
return da, db, None, None, None, \
|
||||
None, None, None, None, \
|
||||
@@ -427,11 +424,9 @@ class matmul:
|
||||
self.db_lut, self.db_width = sdd_lut(layout, block, device)
|
||||
|
||||
def __call__(self, a, b, out=None):
|
||||
c = _matmul.apply(
|
||||
a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
|
||||
self.c_lut, self.c_width,
|
||||
self.da_lut, self.da_width,
|
||||
self.db_lut, self.db_width,
|
||||
out
|
||||
)
|
||||
c = _matmul.apply(a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, #
|
||||
self.c_lut, self.c_width, #
|
||||
self.da_lut, self.da_width, #
|
||||
self.db_lut, self.db_width, #
|
||||
out)
|
||||
return c
|
||||
|
||||
@@ -18,14 +18,13 @@ def num_warps(n):
|
||||
|
||||
|
||||
@jit
|
||||
def _blocksparse_softmax_fwd(
|
||||
Out, A, stride_xz, LUT,
|
||||
R, extent, stride_zr, stride_hr, # relative attention
|
||||
scale, is_causal,
|
||||
ROW_SIZE: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
IS_DENSE: tl.constexpr,
|
||||
):
|
||||
def _blocksparse_softmax_fwd(Out, A, stride_xz, LUT, #
|
||||
R, extent, stride_zr, stride_hr, # relative attention
|
||||
scale, is_causal, #
|
||||
ROW_SIZE: tl.constexpr, #
|
||||
BLOCK_SIZE: tl.constexpr, #
|
||||
IS_DENSE: tl.constexpr #
|
||||
):
|
||||
h = tl.program_id(0)
|
||||
m = tl.program_id(1)
|
||||
z = tl.program_id(2)
|
||||
@@ -73,18 +72,16 @@ def _blocksparse_softmax_fwd(
|
||||
|
||||
|
||||
@jit
|
||||
def _blocksparse_softmax_bwd(
|
||||
DA, stride_zdx,
|
||||
DOut, stride_zdout,
|
||||
Out, stride_zout,
|
||||
scale,
|
||||
LUT,
|
||||
DR, extent, stride_zr, stride_hr, stride_er,
|
||||
is_causal,
|
||||
ROW_SIZE: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
IS_DENSE: tl.constexpr,
|
||||
):
|
||||
def _blocksparse_softmax_bwd(DA, stride_zdx, #
|
||||
DOut, stride_zdout, #
|
||||
Out, stride_zout, #
|
||||
scale, #
|
||||
LUT, #
|
||||
DR, extent, stride_zr, stride_hr, stride_er, #
|
||||
is_causal, #
|
||||
ROW_SIZE: tl.constexpr, #
|
||||
BLOCK_SIZE: tl.constexpr, #
|
||||
IS_DENSE: tl.constexpr):
|
||||
h = tl.program_id(0)
|
||||
m = tl.program_id(1)
|
||||
z = tl.program_id(2)
|
||||
@@ -133,6 +130,7 @@ def _blocksparse_softmax_bwd(
|
||||
|
||||
|
||||
class _softmax(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def make_lut(layout, block, device):
|
||||
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
|
||||
@@ -151,10 +149,7 @@ class _softmax(torch.autograd.Function):
|
||||
return lut, int(total_sizes.max())
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx, a, scale, rel_logits, is_causal,
|
||||
spdims, block, lut, maxlut, is_dense
|
||||
):
|
||||
def forward(ctx, a, scale, rel_logits, is_causal, spdims, block, lut, maxlut, is_dense):
|
||||
if scale is not None and isinstance(scale, torch.Tensor):
|
||||
assert scale.device.type == "cpu"
|
||||
scale = scale.item()
|
||||
@@ -165,14 +160,14 @@ class _softmax(torch.autograd.Function):
|
||||
# enqueue kernel
|
||||
out = torch.empty_like(a)
|
||||
_blocksparse_softmax_fwd[grid](
|
||||
out, a, a.stride(0), lut,
|
||||
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
|
||||
scale,
|
||||
is_causal,
|
||||
BLOCK_SIZE=block,
|
||||
ROW_SIZE=next_power_of_2(maxlut),
|
||||
IS_DENSE=is_dense,
|
||||
num_warps=num_warps(maxlut)
|
||||
out, a, a.stride(0), lut, #
|
||||
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn#
|
||||
scale, #
|
||||
is_causal, #
|
||||
BLOCK_SIZE=block, #
|
||||
ROW_SIZE=next_power_of_2(maxlut), #
|
||||
IS_DENSE=is_dense, #
|
||||
num_warps=num_warps(maxlut) #
|
||||
)
|
||||
# save to context
|
||||
# ctx.mark_dirty(x)
|
||||
@@ -201,28 +196,23 @@ class _softmax(torch.autograd.Function):
|
||||
grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)
|
||||
da = torch.empty_like(dout)
|
||||
_blocksparse_softmax_bwd[grid](
|
||||
da, da.stride(0),
|
||||
dout, dout.stride(0),
|
||||
out, out.stride(0),
|
||||
ctx.scale,
|
||||
lut,
|
||||
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],
|
||||
ctx.is_causal,
|
||||
BLOCK_SIZE=ctx.block,
|
||||
ROW_SIZE=next_power_of_2(ctx.maxlut),
|
||||
IS_DENSE=ctx.is_dense,
|
||||
num_warps=num_warps(ctx.maxlut)
|
||||
da, da.stride(0), #
|
||||
dout, dout.stride(0), #
|
||||
out, out.stride(0), #
|
||||
ctx.scale, #
|
||||
lut, #
|
||||
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2], #
|
||||
ctx.is_causal, #
|
||||
BLOCK_SIZE=ctx.block, #
|
||||
ROW_SIZE=next_power_of_2(ctx.maxlut), #
|
||||
IS_DENSE=ctx.is_dense, #
|
||||
num_warps=num_warps(ctx.maxlut) #
|
||||
)
|
||||
return (da, None, None, dr, None,
|
||||
None, None, None, None, None,
|
||||
None,
|
||||
None, None, None,
|
||||
None,
|
||||
None, None, None
|
||||
)
|
||||
return (da, None, None, dr, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
|
||||
|
||||
|
||||
class softmax:
|
||||
|
||||
def __init__(self, layout, block, device, is_dense=False):
|
||||
self.spdims = layout.shape
|
||||
self.layout = layout
|
||||
@@ -233,8 +223,6 @@ class softmax:
|
||||
def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False):
|
||||
if rel_logits is not None and rel_logits.dtype != a.dtype:
|
||||
raise ValueError(f"relative position embedding must be {a.dtype}")
|
||||
a = _softmax.apply(
|
||||
a, scale, rel_logits, is_causal,
|
||||
self.spdims, self.block, self.lut, self.maxlut, self.is_dense,
|
||||
)
|
||||
a = _softmax.apply(a, scale, rel_logits, is_causal, self.spdims, self.block, self.lut, self.maxlut,
|
||||
self.is_dense)
|
||||
return a
|
||||
|
||||
@@ -59,6 +59,7 @@ def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
|
||||
|
||||
|
||||
class _cross_entropy(torch.autograd.Function):
|
||||
|
||||
@classmethod
|
||||
def forward(cls, ctx, logits, indices):
|
||||
# make sure we can use triton
|
||||
|
||||
@@ -15,20 +15,19 @@ from .. import language as tl
|
||||
|
||||
|
||||
@jit
|
||||
def _fwd_kernel(
|
||||
Q, K, V, sm_scale,
|
||||
L,
|
||||
Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
stride_oz, stride_oh, stride_om, stride_on,
|
||||
Z, H, N_CTX,
|
||||
Z_H_N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
):
|
||||
def _fwd_kernel(Q, K, V, sm_scale, #
|
||||
L, #
|
||||
Out, #
|
||||
stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vn, stride_vk, #
|
||||
stride_oz, stride_oh, stride_om, stride_on, #
|
||||
Z, H, N_CTX, #
|
||||
Z_H_N_CTX, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
|
||||
BLOCK_N: tl.constexpr, #
|
||||
IS_CAUSAL: tl.constexpr #
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
qvk_offset = off_hz * stride_qh
|
||||
@@ -40,7 +39,7 @@ def _fwd_kernel(
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, vk_offset),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||
order=(0, 1)
|
||||
order=(0, 1),
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V,
|
||||
@@ -48,7 +47,7 @@ def _fwd_kernel(
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(vk_offset, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
order=(1, 0),
|
||||
)
|
||||
# initialize offsets
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
@@ -104,7 +103,7 @@ def _fwd_kernel(
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(vk_offset + start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
order=(1, 0),
|
||||
)
|
||||
# O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
|
||||
tl.store(O_block_ptr, acc.to(K.dtype.element_ty))
|
||||
@@ -112,9 +111,11 @@ def _fwd_kernel(
|
||||
|
||||
@jit
|
||||
def _bwd_preprocess(
|
||||
Out, DO,
|
||||
Out,
|
||||
DO,
|
||||
Delta,
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
D_HEAD: tl.constexpr,
|
||||
):
|
||||
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, D_HEAD)
|
||||
@@ -128,40 +129,48 @@ def _bwd_preprocess(
|
||||
|
||||
|
||||
@jit
|
||||
def _bwd_kernel_one_col_block(
|
||||
Q, K, V, sm_scale, qk_scale,
|
||||
Out, DO,
|
||||
DQ, DK, DV,
|
||||
L,
|
||||
D,
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
Z, H, N_CTX,
|
||||
off_hz, start_n, num_block,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
SEQUENCE_PARALLEL: tl.constexpr,
|
||||
CAUSAL: tl.constexpr,
|
||||
MMA_V3: tl.constexpr
|
||||
):
|
||||
if SEQUENCE_PARALLEL:
|
||||
DQ += stride_dqa.to(tl.int64) * start_n
|
||||
def _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, #
|
||||
Out, DO, #
|
||||
DQ, DK, DV, #
|
||||
L, #
|
||||
D, #
|
||||
Q_block_ptr, K_block_ptr, V_block_ptr, #
|
||||
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vn, stride_vk, #
|
||||
Z, H, N_CTX, #
|
||||
off_h, off_z, off_hz, start_n, num_block, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
|
||||
BLOCK_N: tl.constexpr, #
|
||||
SEQUENCE_PARALLEL: tl.constexpr, #
|
||||
CAUSAL: tl.constexpr, #
|
||||
MMA_V3: tl.constexpr #
|
||||
):
|
||||
if CAUSAL:
|
||||
lo = start_n * BLOCK_M
|
||||
else:
|
||||
lo = 0
|
||||
|
||||
Q_offset = (off_z * stride_qz + off_h * stride_qh) // stride_qm
|
||||
DQ_offset = off_z * stride_qz + off_h * stride_qh
|
||||
K_offset = (off_z * stride_kz + off_h * stride_kh) // stride_kn
|
||||
V_offset = (off_z * stride_vz + off_h * stride_vh) // stride_vn
|
||||
if SEQUENCE_PARALLEL:
|
||||
DQ_offset += stride_dqa.to(tl.int64) * start_n
|
||||
DQ_offset = DQ_offset // stride_qm
|
||||
|
||||
Q_block_ptr = tl.advance(Q_block_ptr, (lo + Q_offset, 0))
|
||||
K_block_ptr = tl.advance(K_block_ptr, (start_n * BLOCK_M + K_offset, 0))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (start_n * BLOCK_M + V_offset, 0))
|
||||
DO_block_ptr = tl.advance(DO_block_ptr, (lo + Q_offset, 0))
|
||||
DQ_block_ptr = tl.advance(DQ_block_ptr, (lo + DQ_offset, 0))
|
||||
DK_block_ptr = tl.advance(DK_block_ptr, (start_n * BLOCK_M + K_offset, 0))
|
||||
DV_block_ptr = tl.advance(DV_block_ptr, (start_n * BLOCK_M + V_offset, 0))
|
||||
|
||||
# initialize row/col offsets
|
||||
offs_qm = lo + tl.arange(0, BLOCK_M)
|
||||
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_m = tl.arange(0, BLOCK_N)
|
||||
offs_k = tl.arange(0, BLOCK_DMODEL)
|
||||
# initialize pointers to value-like data
|
||||
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
|
||||
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk)
|
||||
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
# pointer to row-wise quantities in value-like data
|
||||
D_ptrs = D + off_hz * N_CTX
|
||||
l_ptrs = L + off_hz * N_CTX
|
||||
@@ -169,17 +178,17 @@ def _bwd_kernel_one_col_block(
|
||||
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# k and v stay in SRAM throughout
|
||||
k = tl.load(k_ptrs)
|
||||
v = tl.load(v_ptrs)
|
||||
k = tl.load(K_block_ptr)
|
||||
v = tl.load(V_block_ptr)
|
||||
# loop over rows
|
||||
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
|
||||
offs_m_curr = start_m + offs_m
|
||||
# load q, k, v, do on-chip
|
||||
q = tl.load(q_ptrs)
|
||||
q = tl.load(Q_block_ptr)
|
||||
# recompute p = softmax(qk, dim=-1).T
|
||||
# NOTE: `do` is pre-divided by `l`; no normalization here
|
||||
if CAUSAL:
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.0), float("-inf"))
|
||||
else:
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, tl.trans(k))
|
||||
@@ -187,7 +196,7 @@ def _bwd_kernel_one_col_block(
|
||||
l_i = tl.load(l_ptrs + offs_m_curr)
|
||||
p = tl.math.exp2(qk - l_i[:, None])
|
||||
# compute dv
|
||||
do = tl.load(do_ptrs)
|
||||
do = tl.load(DO_block_ptr)
|
||||
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do, allow_tf32=True)
|
||||
# compute dp = dot(v, do)
|
||||
Di = tl.load(D_ptrs + offs_m_curr)
|
||||
@@ -199,97 +208,156 @@ def _bwd_kernel_one_col_block(
|
||||
dk += tl.dot(tl.trans(ds), q, allow_tf32=True)
|
||||
# compute dq
|
||||
if not SEQUENCE_PARALLEL:
|
||||
dq = tl.load(dq_ptrs)
|
||||
dq = tl.load(DQ_block_ptr)
|
||||
dq += tl.dot(ds, k, allow_tf32=True)
|
||||
tl.store(dq_ptrs, dq)
|
||||
tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty))
|
||||
elif SEQUENCE_PARALLEL:
|
||||
if MMA_V3:
|
||||
dq = tl.dot(ds, k, allow_tf32=True)
|
||||
else:
|
||||
# not work with mma v3, becuase M % 64 != 0
|
||||
dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds), allow_tf32=True))
|
||||
tl.store(dq_ptrs, dq)
|
||||
tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty))
|
||||
|
||||
# increment pointers
|
||||
dq_ptrs += BLOCK_M * stride_qm
|
||||
q_ptrs += BLOCK_M * stride_qm
|
||||
do_ptrs += BLOCK_M * stride_qm
|
||||
DQ_block_ptr = tl.advance(DQ_block_ptr, (BLOCK_M, 0))
|
||||
Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0))
|
||||
DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0))
|
||||
# write-back
|
||||
dv_ptrs = DV + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk)
|
||||
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
|
||||
tl.store(dv_ptrs, dv)
|
||||
tl.store(dk_ptrs, dk)
|
||||
tl.store(DV_block_ptr, dv.to(V.dtype.element_ty))
|
||||
tl.store(DK_block_ptr, dk.to(K.dtype.element_ty))
|
||||
|
||||
|
||||
@jit
|
||||
def _bwd_kernel(
|
||||
# fmt: off
|
||||
Q, K, V, sm_scale,
|
||||
Out, DO,
|
||||
DQ, DK, DV,
|
||||
L,
|
||||
D,
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
Z, H, N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
SEQUENCE_PARALLEL: tl.constexpr,
|
||||
CAUSAL: tl.constexpr,
|
||||
MMA_V3: tl.constexpr
|
||||
# fmt: on
|
||||
):
|
||||
def _bwd_kernel(Q, K, V, sm_scale, #
|
||||
Out, DO, #
|
||||
DQ, DK, DV, #
|
||||
L, #
|
||||
D, #
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vn, stride_vk, #
|
||||
Z, H, N_CTX, #
|
||||
Z_H_N_CTX, #
|
||||
SQ_Z_H_N_CTX, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
|
||||
BLOCK_N: tl.constexpr, #
|
||||
SEQUENCE_PARALLEL: tl.constexpr, #
|
||||
CAUSAL: tl.constexpr, #
|
||||
MMA_V3: tl.constexpr #
|
||||
):
|
||||
qk_scale = sm_scale * 1.44269504
|
||||
off_hz = tl.program_id(0)
|
||||
off_z = off_hz // H
|
||||
off_h = off_hz % H
|
||||
# offset pointers for batch/head
|
||||
Q += off_z * stride_qz + off_h * stride_qh
|
||||
K += off_z * stride_kz + off_h * stride_kh
|
||||
V += off_z * stride_vz + off_h * stride_vh
|
||||
DO += off_z * stride_qz + off_h * stride_qh
|
||||
DQ += off_z * stride_qz + off_h * stride_qh
|
||||
DK += off_z * stride_kz + off_h * stride_kh
|
||||
DV += off_z * stride_vz + off_h * stride_vh
|
||||
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_kn, stride_kk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
DO_block_ptr = tl.make_block_ptr(
|
||||
base=DO,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
if SEQUENCE_PARALLEL:
|
||||
DQ_block_ptr = tl.make_block_ptr(
|
||||
base=DQ,
|
||||
shape=(SQ_Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
else:
|
||||
DQ_block_ptr = tl.make_block_ptr(
|
||||
base=DQ,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
|
||||
DK_block_ptr = tl.make_block_ptr(
|
||||
base=DK,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_kn, stride_kk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
DV_block_ptr = tl.make_block_ptr(
|
||||
base=DV,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
|
||||
num_block_n = tl.cdiv(N_CTX, BLOCK_N)
|
||||
if not SEQUENCE_PARALLEL:
|
||||
for start_n in range(0, num_block_n):
|
||||
_bwd_kernel_one_col_block(
|
||||
Q, K, V, sm_scale, qk_scale, Out, DO,
|
||||
DQ, DK, DV,
|
||||
L,
|
||||
D,
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
Z, H, N_CTX,
|
||||
off_hz, start_n, num_block_n,
|
||||
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_N=BLOCK_N,
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
|
||||
CAUSAL=CAUSAL,
|
||||
MMA_V3=MMA_V3
|
||||
)
|
||||
_bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, #
|
||||
DQ, DK, DV, #
|
||||
L, #
|
||||
D, #
|
||||
Q_block_ptr, K_block_ptr, V_block_ptr, #
|
||||
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vn, stride_vk, #
|
||||
Z, H, N_CTX, #
|
||||
off_h, off_z, off_hz, start_n, num_block_n, #
|
||||
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, #
|
||||
BLOCK_N=BLOCK_N, #
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, #
|
||||
CAUSAL=CAUSAL, #
|
||||
MMA_V3=MMA_V3 #
|
||||
)
|
||||
else:
|
||||
start_n = tl.program_id(1)
|
||||
_bwd_kernel_one_col_block(
|
||||
Q, K, V, sm_scale, qk_scale, Out, DO,
|
||||
DQ, DK, DV,
|
||||
L,
|
||||
D,
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
Z, H, N_CTX,
|
||||
off_hz, start_n, num_block_n,
|
||||
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_N=BLOCK_N,
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
|
||||
CAUSAL=CAUSAL,
|
||||
MMA_V3=MMA_V3
|
||||
)
|
||||
_bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, #
|
||||
DQ, DK, DV, #
|
||||
L, #
|
||||
D, #
|
||||
Q_block_ptr, K_block_ptr, V_block_ptr, #
|
||||
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
|
||||
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vn, stride_vk, #
|
||||
Z, H, N_CTX, #
|
||||
off_h, off_z, off_hz, start_n, num_block_n, #
|
||||
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, #
|
||||
BLOCK_N=BLOCK_N, #
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, #
|
||||
CAUSAL=CAUSAL, #
|
||||
MMA_V3=MMA_V3 #
|
||||
)
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
@@ -315,19 +383,20 @@ class _attention(torch.autograd.Function):
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
_fwd_kernel[grid](
|
||||
q, k, v, sm_scale,
|
||||
L,
|
||||
o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
q.shape[0] * q.shape[1] * q.shape[2],
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
|
||||
IS_CAUSAL=causal,
|
||||
num_warps=num_warps,
|
||||
num_stages=4)
|
||||
q, k, v, sm_scale, #
|
||||
L, #
|
||||
o, #
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
|
||||
q.shape[0], q.shape[1], q.shape[2], #
|
||||
q.shape[0] * q.shape[1] * q.shape[2], #
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, #
|
||||
IS_CAUSAL=causal, #
|
||||
num_warps=num_warps, #
|
||||
num_stages=4 #
|
||||
)
|
||||
|
||||
ctx.save_for_backward(q, k, v, o, L)
|
||||
ctx.grid = grid
|
||||
@@ -348,35 +417,39 @@ class _attention(torch.autograd.Function):
|
||||
do = do.contiguous()
|
||||
if sequence_parallel:
|
||||
replicas = cdiv(seq_len_kv, BLOCK)
|
||||
new_dq_shape = (replicas,) + q.shape
|
||||
new_dq_shape = (replicas, ) + q.shape
|
||||
dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype)
|
||||
else:
|
||||
dq = torch.zeros_like(q, dtype=torch.float32)
|
||||
dq = torch.zeros_like(q, dtype=q.dtype)
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
delta = torch.empty_like(L)
|
||||
_bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1], )](
|
||||
o, do,
|
||||
o,
|
||||
do,
|
||||
delta,
|
||||
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
BLOCK_M=BLOCK,
|
||||
D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
_bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do,
|
||||
dq, dk, dv,
|
||||
L,
|
||||
delta,
|
||||
o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
|
||||
SEQUENCE_PARALLEL=sequence_parallel,
|
||||
CAUSAL=ctx.causal,
|
||||
MMA_V3=MMA_V3,
|
||||
num_warps=8,
|
||||
num_stages=1,
|
||||
q, k, v, ctx.sm_scale, #
|
||||
o, do, #
|
||||
dq, dk, dv, #
|
||||
L, #
|
||||
delta, #
|
||||
o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
|
||||
q.shape[0], q.shape[1], q.shape[2], #
|
||||
q.shape[0] * q.shape[1] * q.shape[2], #
|
||||
cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2], #
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK, #
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, #
|
||||
SEQUENCE_PARALLEL=sequence_parallel, #
|
||||
CAUSAL=ctx.causal, #
|
||||
MMA_V3=MMA_V3, #
|
||||
num_warps=8, #
|
||||
num_stages=1 #
|
||||
)
|
||||
|
||||
if len(dq.shape) == 5:
|
||||
|
||||
@@ -37,8 +37,9 @@ def get_configs_io_bound():
|
||||
num_stages=num_stages, num_warps=num_warps))
|
||||
# split_k
|
||||
for split_k in [2, 4, 8, 16]:
|
||||
configs.append(Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||
configs.append(
|
||||
Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||
return configs
|
||||
|
||||
|
||||
@@ -69,22 +70,22 @@ def get_configs_io_bound():
|
||||
prune_configs_by={
|
||||
'early_config_prune': early_config_prune,
|
||||
'perf_model': estimate_matmul_time,
|
||||
'top_k': 10
|
||||
'top_k': 10,
|
||||
},
|
||||
)
|
||||
@heuristics({
|
||||
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||
})
|
||||
@jit
|
||||
def _kernel(A, B, C, M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
dot_out_dtype: tl.constexpr,
|
||||
allow_tf32: tl.constexpr,
|
||||
fp8_fast_accum: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr
|
||||
def _kernel(A, B, C, M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
dot_out_dtype: tl.constexpr, #
|
||||
allow_tf32: tl.constexpr, #
|
||||
fp8_fast_accum: tl.constexpr, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr #
|
||||
):
|
||||
# matrix multiplication
|
||||
pid = tl.program_id(0)
|
||||
@@ -184,14 +185,15 @@ class _matmul(torch.autograd.Function):
|
||||
ab_dtype = False
|
||||
# launch kernel
|
||||
grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||
_kernel[grid](a, b, c, M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
dot_out_dtype=dot_out_dtype,
|
||||
allow_tf32=allow_tf32,
|
||||
fp8_fast_accum=fp8_fast_accum,
|
||||
GROUP_M=8, AB_DTYPE=ab_dtype)
|
||||
_kernel[grid](
|
||||
a, b, c, M, N, K, #
|
||||
a.stride(0), a.stride(1), #
|
||||
b.stride(0), b.stride(1), #
|
||||
c.stride(0), c.stride(1), #
|
||||
dot_out_dtype=dot_out_dtype, #
|
||||
allow_tf32=allow_tf32, #
|
||||
fp8_fast_accum=fp8_fast_accum, #
|
||||
GROUP_M=8, AB_DTYPE=ab_dtype)
|
||||
return c
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -5,8 +5,7 @@ import torch
|
||||
from .. import cdiv
|
||||
from .._C.libtriton.triton import runtime
|
||||
from ..runtime import driver
|
||||
from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops,
|
||||
nvsmi)
|
||||
from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops, nvsmi)
|
||||
|
||||
|
||||
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
@@ -14,7 +13,8 @@ def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
total_warps = num_ctas * min(num_warps, 4)
|
||||
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
|
||||
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
|
||||
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, cur_sm_clock, backend, device)
|
||||
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(
|
||||
dtype, cur_sm_clock, backend, device)
|
||||
return tflops
|
||||
|
||||
|
||||
@@ -35,12 +35,12 @@ def get_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
|
||||
|
||||
def estimate_matmul_time(
|
||||
# backend, device,
|
||||
num_warps, num_stages,
|
||||
A, B, C,
|
||||
M, N, K,
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
|
||||
debug=False, **kwargs
|
||||
# backend, device,
|
||||
num_warps, num_stages, #
|
||||
A, B, C, #
|
||||
M, N, K, #
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, #
|
||||
debug=False, **kwargs #
|
||||
):
|
||||
''' return estimated running time in ms
|
||||
= max(compute, loading) + store '''
|
||||
@@ -149,8 +149,9 @@ def early_config_prune(configs, named_args):
|
||||
optimal_num_stages = ldgsts_latency / mma_cycles
|
||||
|
||||
# nearest stages, prefer large #stages
|
||||
nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
|
||||
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
|
||||
nearest = heapq.nsmallest(
|
||||
2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
|
||||
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
|
||||
|
||||
for n in nearest:
|
||||
pruned_configs.append(n[0])
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune,
|
||||
heuristics)
|
||||
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune, heuristics)
|
||||
from .driver import driver
|
||||
from .jit import (JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret,
|
||||
version_key)
|
||||
from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret
|
||||
|
||||
__all__ = [
|
||||
"driver",
|
||||
@@ -12,7 +10,6 @@ __all__ = [
|
||||
"heuristics",
|
||||
"JITFunction",
|
||||
"KernelInterface",
|
||||
"version_key",
|
||||
"reinterpret",
|
||||
"TensorWrapper",
|
||||
"OutOfResources",
|
||||
|
||||
@@ -9,11 +9,10 @@ from .jit import KernelInterface
|
||||
|
||||
|
||||
class OutOfResources(Exception):
|
||||
|
||||
def __init__(self, required, limit, name):
|
||||
self.message = f'out of resource: {name}, '\
|
||||
f'Required: {required}, '\
|
||||
f'Hardware limit: {limit}'
|
||||
self.message += '. Reducing block sizes or `num_stages` may help.'
|
||||
self.message = (f"out of resource: {name}, Required: {required}, Hardware limit: {limit}. " +
|
||||
"Reducing block sizes or `num_stages` may help.")
|
||||
self.required = required
|
||||
self.limit = limit
|
||||
self.name = name
|
||||
@@ -25,38 +24,77 @@ class OutOfResources(Exception):
|
||||
|
||||
|
||||
class Autotuner(KernelInterface):
|
||||
<<<<<<< HEAD
|
||||
def __init__(self, fn, arg_names, configs, key, verbose, reset_to_zero, prune_configs_by: Dict = None, warmup=25, rep=100):
|
||||
'''
|
||||
=======
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fn,
|
||||
arg_names,
|
||||
configs,
|
||||
key,
|
||||
reset_to_zero,
|
||||
restore_value,
|
||||
prune_configs_by: Dict = None,
|
||||
warmup=25,
|
||||
rep=100,
|
||||
):
|
||||
"""
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predicate running time with different configs, returns running time
|
||||
'top_k': number of configs to bench
|
||||
'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
|
||||
'''
|
||||
"""
|
||||
if not configs:
|
||||
self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)]
|
||||
else:
|
||||
self.configs = configs
|
||||
self.key_idx = [arg_names.index(k) for k in key]
|
||||
self.cache = {}
|
||||
# hook to reset all required tensor to zeros before relaunching a kernel
|
||||
self.hook = lambda args: 0
|
||||
self.arg_names = arg_names
|
||||
|
||||
# Reset to zero or restore values
|
||||
self.reset_idx = []
|
||||
if reset_to_zero is not None:
|
||||
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
|
||||
self.restore_idx = []
|
||||
if restore_value is not None:
|
||||
self.restore_idx = [arg_names.index(k) for k in restore_value]
|
||||
|
||||
def _hook(args):
|
||||
# Hook to reset or restore for required tensors
|
||||
self.pre_hook = lambda args, reset_only=False: 0
|
||||
self.post_hook = lambda args: 0
|
||||
if len(self.reset_idx) > 0 or len(self.restore_idx) > 0:
|
||||
|
||||
def _pre_hook(args, reset_only=False):
|
||||
for i in self.reset_idx:
|
||||
args[i].zero_()
|
||||
self.hook = _hook
|
||||
self.arg_names = arg_names
|
||||
# prune configs
|
||||
if not reset_only:
|
||||
self.restore_copies = [args[i].clone() for i in self.restore_idx]
|
||||
|
||||
self.pre_hook = _pre_hook
|
||||
if len(self.restore_idx) > 0:
|
||||
|
||||
def _post_hook(args):
|
||||
for i, j in enumerate(self.restore_idx):
|
||||
args[j].copy_(self.restore_copies[i])
|
||||
self.restore_copies = []
|
||||
|
||||
self.post_hook = _post_hook
|
||||
|
||||
# Prune configs
|
||||
if prune_configs_by:
|
||||
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
|
||||
if 'early_config_prune' in prune_configs_by:
|
||||
early_config_prune = prune_configs_by['early_config_prune']
|
||||
perf_model, top_k = prune_configs_by["perf_model"], prune_configs_by["top_k"]
|
||||
if "early_config_prune" in prune_configs_by:
|
||||
early_config_prune = prune_configs_by["early_config_prune"]
|
||||
else:
|
||||
perf_model, top_k, early_config_prune = None, None, None
|
||||
self.perf_model, self.configs_top_k = perf_model, top_k
|
||||
self.early_config_prune = early_config_prune
|
||||
|
||||
self.fn = fn
|
||||
self.warmup = warmup
|
||||
self.rep = rep
|
||||
@@ -67,10 +105,8 @@ class Autotuner(KernelInterface):
|
||||
# as kwargs and by the autotuner
|
||||
conflicts = meta.keys() & config.kwargs.keys()
|
||||
if conflicts:
|
||||
raise ValueError(
|
||||
f"Conflicting meta-parameters: {', '.join(conflicts)}."
|
||||
" Make sure that you don't re-define auto-tuned symbols."
|
||||
)
|
||||
raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}."
|
||||
" Make sure that you don't re-define auto-tuned symbols.")
|
||||
# augment meta-parameters with tunable ones
|
||||
current = dict(meta, **config.kwargs)
|
||||
full_nargs = {**self.nargs, **current}
|
||||
@@ -78,16 +114,22 @@ class Autotuner(KernelInterface):
|
||||
def kernel_call():
|
||||
if config.pre_hook:
|
||||
config.pre_hook(full_nargs)
|
||||
self.hook(args)
|
||||
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
# enable_persistent=False,
|
||||
**current)
|
||||
self.pre_hook(args)
|
||||
self.fn.run(
|
||||
*args,
|
||||
num_warps=config.num_warps,
|
||||
num_stages=config.num_stages,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
# enable_persistent=False,
|
||||
**current,
|
||||
)
|
||||
self.post_hook(args)
|
||||
|
||||
try:
|
||||
return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
|
||||
except OutOfResources:
|
||||
return [float('inf'), float('inf'), float('inf')]
|
||||
return [float("inf"), float("inf"), float("inf")]
|
||||
|
||||
def get_best_config(self):
|
||||
return self.best_config
|
||||
@@ -110,12 +152,11 @@ class Autotuner(KernelInterface):
|
||||
# prune configs
|
||||
pruned_configs = self.prune_configs(kwargs)
|
||||
bench_start = time.time()
|
||||
timings = {config: self._bench(*args, config=config, **kwargs)
|
||||
for config in pruned_configs}
|
||||
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
|
||||
bench_end = time.time()
|
||||
self.bench_time = bench_end - bench_start
|
||||
self.cache[key] = builtins.min(timings, key=timings.get)
|
||||
self.hook(args)
|
||||
self.pre_hook(args, reset_only=True)
|
||||
self.configs_timings = timings
|
||||
if self.verbose:
|
||||
print(str(key) + ": " + str(self.cache[key]))
|
||||
@@ -126,9 +167,15 @@ class Autotuner(KernelInterface):
|
||||
full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs}
|
||||
if config.pre_hook is not None:
|
||||
config.pre_hook(full_nargs)
|
||||
ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization, **kwargs, **config.kwargs)
|
||||
ret = self.fn.run(
|
||||
*args,
|
||||
num_warps=config.num_warps,
|
||||
num_stages=config.num_stages,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
**kwargs,
|
||||
**config.kwargs,
|
||||
)
|
||||
self.nargs = None
|
||||
return ret
|
||||
|
||||
@@ -142,17 +189,20 @@ class Autotuner(KernelInterface):
|
||||
top_k = int(len(self.configs) * top_k)
|
||||
if len(pruned_configs) > top_k:
|
||||
est_timing = {
|
||||
config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
|
||||
num_warps=config.num_warps,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
enable_persistent=config.enable_persistent)
|
||||
config:
|
||||
self.perf_model(
|
||||
**self.nargs,
|
||||
**kwargs,
|
||||
**config.kwargs,
|
||||
num_stages=config.num_stages,
|
||||
num_warps=config.num_warps,
|
||||
num_ctas=config.num_ctas,
|
||||
enable_warp_specialization=config.enable_warp_specialization,
|
||||
enable_persistent=config.enable_persistent,
|
||||
)
|
||||
for config in pruned_configs
|
||||
}
|
||||
pruned_configs = sorted(
|
||||
est_timing.keys(),
|
||||
key=lambda x: est_timing[x])[
|
||||
:top_k]
|
||||
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
|
||||
return pruned_configs
|
||||
|
||||
def warmup(self, *args, **kwargs):
|
||||
@@ -195,13 +245,14 @@ class Config:
|
||||
self.num_ctas = num_ctas
|
||||
self.num_stages = num_stages
|
||||
self.enable_warp_specialization = enable_warp_specialization
|
||||
# TODO[shuhaoj]: May make enable_persistent configurable in future if necessay.
|
||||
# TODO[shuhaoj]: May make enable_persistent configurable in future if necessary.
|
||||
self.enable_persistent = False
|
||||
self.pre_hook = pre_hook
|
||||
|
||||
def __str__(self):
|
||||
res = []
|
||||
for k, v in self.kwargs.items():
|
||||
<<<<<<< HEAD
|
||||
res.append(f'{k}: {v}')
|
||||
res.append(f'num_warps: {self.num_warps}')
|
||||
## Comment out Hopper specific parameters
|
||||
@@ -214,6 +265,18 @@ class Config:
|
||||
|
||||
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=False, warmup=25, rep=100):
|
||||
=======
|
||||
res.append(f"{k}: {v}")
|
||||
res.append(f"num_warps: {self.num_warps}")
|
||||
res.append(f"num_ctas: {self.num_ctas}")
|
||||
res.append(f"num_stages: {self.num_stages}")
|
||||
res.append(f"enable_warp_specialization: {self.enable_warp_specialization}")
|
||||
res.append(f"enable_persistent: {self.enable_persistent}")
|
||||
return ", ".join(res)
|
||||
|
||||
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, warmup=25, rep=100):
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
"""
|
||||
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
||||
|
||||
@@ -244,6 +307,8 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=Fa
|
||||
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
|
||||
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
|
||||
:type reset_to_zero: list[str]
|
||||
:param restore_value: a list of argument names whose value will be restored after evaluating any configs.
|
||||
:type restore_value: list[str]
|
||||
:param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25.
|
||||
:type warmup: int
|
||||
:param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
|
||||
@@ -251,8 +316,13 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=Fa
|
||||
:param verbose: a boolean that controls whether the best_config for each key is printed
|
||||
:type verbose: bool
|
||||
"""
|
||||
|
||||
def decorator(fn):
|
||||
<<<<<<< HEAD
|
||||
return Autotuner(fn, fn.arg_names, configs, key, verbose, reset_to_zero, prune_configs_by, warmup, rep)
|
||||
=======
|
||||
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, prune_configs_by, warmup, rep)
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -286,6 +356,7 @@ def heuristics(values):
|
||||
each such function takes a list of positional arguments as input.
|
||||
:type values: dict[str, Callable[[list[Any]], Any]]
|
||||
"""
|
||||
|
||||
def decorator(fn):
|
||||
return Heuristics(fn, fn.arg_names, values)
|
||||
|
||||
|
||||
@@ -1,27 +1,42 @@
|
||||
#include "cuda.h"
|
||||
#include <dlfcn.h>
|
||||
#include <stdbool.h>
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
|
||||
static inline void gpuAssert(CUresult code, const char *file, int line) {
|
||||
if (code != CUDA_SUCCESS) {
|
||||
const char *prefix = "Triton Error [CUDA]: ";
|
||||
const char *str;
|
||||
cuGetErrorString(code, &str);
|
||||
char err[1024] = {0};
|
||||
strcat(err, prefix);
|
||||
strcat(err, str);
|
||||
PyGILState_STATE gil_state;
|
||||
gil_state = PyGILState_Ensure();
|
||||
PyErr_SetString(PyExc_RuntimeError, err);
|
||||
PyGILState_Release(gil_state);
|
||||
}
|
||||
// Raises a Python exception and returns false if code is not CUDA_SUCCESS.
|
||||
static bool gpuAssert(CUresult code, const char *file, int line) {
|
||||
if (code == CUDA_SUCCESS)
|
||||
return true;
|
||||
|
||||
const char *prefix = "Triton Error [CUDA]: ";
|
||||
const char *str;
|
||||
cuGetErrorString(code, &str);
|
||||
char err[1024] = {0};
|
||||
strcat(err, prefix);
|
||||
strcat(err, str);
|
||||
PyGILState_STATE gil_state;
|
||||
gil_state = PyGILState_Ensure();
|
||||
PyErr_SetString(PyExc_RuntimeError, err);
|
||||
PyGILState_Release(gil_state);
|
||||
return false;
|
||||
}
|
||||
|
||||
#define CUDA_CHECK(ans) \
|
||||
{ \
|
||||
{ gpuAssert((ans), __FILE__, __LINE__); } \
|
||||
}
|
||||
// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block.
|
||||
#define CUDA_CHECK_AND_RETURN_NULL(ans) \
|
||||
do { \
|
||||
if (!gpuAssert((ans), __FILE__, __LINE__)) \
|
||||
return NULL; \
|
||||
} while (0)
|
||||
|
||||
// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block.
|
||||
#define CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \
|
||||
do { \
|
||||
if (!gpuAssert((ans), __FILE__, __LINE__)) { \
|
||||
PyEval_RestoreThread(_save); \
|
||||
return NULL; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define ADD_ENUM_ITEM(value) \
|
||||
do { \
|
||||
@@ -200,16 +215,16 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
|
||||
int sm_clock_rate;
|
||||
int mem_clock_rate;
|
||||
int mem_bus_width;
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
|
||||
&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
|
||||
device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
|
||||
&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate,
|
||||
CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
|
||||
&sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
|
||||
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
|
||||
&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
|
||||
&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
|
||||
|
||||
return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
|
||||
@@ -237,33 +252,37 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
|
||||
CUcontext pctx = 0;
|
||||
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
CUDA_CHECK(cuCtxGetCurrent(&pctx));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&pctx));
|
||||
if (!pctx) {
|
||||
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
|
||||
CUDA_CHECK(cuCtxSetCurrent(pctx));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
||||
cuDevicePrimaryCtxRetain(&pctx, device));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(pctx));
|
||||
}
|
||||
|
||||
CUDA_CHECK(cuModuleLoadData(&mod, data));
|
||||
CUDA_CHECK(cuModuleGetFunction(&fun, mod, name));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuModuleLoadData(&mod, data));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
||||
cuModuleGetFunction(&fun, mod, name));
|
||||
// get allocated registers and spilled registers from the function
|
||||
CUDA_CHECK(cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
|
||||
CUDA_CHECK(
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
||||
cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
||||
cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
|
||||
n_spills /= 4;
|
||||
// set dynamic shared memory if necessary
|
||||
int shared_optin;
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
|
||||
&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
|
||||
device));
|
||||
if (shared > 49152 && shared_optin > 49152) {
|
||||
CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
||||
cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
|
||||
int shared_total, shared_static;
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
|
||||
&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
|
||||
device));
|
||||
CUDA_CHECK(cuFuncGetAttribute(&shared_static,
|
||||
CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
|
||||
CUDA_CHECK(
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute(
|
||||
&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
||||
cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
||||
shared_optin - shared_static));
|
||||
}
|
||||
@@ -286,7 +305,7 @@ static PyObject *memAlloc(PyObject *self, PyObject *args) {
|
||||
}
|
||||
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
CUDA_CHECK(cuMemAlloc(&dptr, bytesize));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuMemAlloc(&dptr, bytesize));
|
||||
Py_END_ALLOW_THREADS;
|
||||
|
||||
return PyLong_FromUnsignedLongLong((unsigned long long)dptr);
|
||||
@@ -307,7 +326,8 @@ static PyObject *memcpyHtoD(PyObject *self, PyObject *args) {
|
||||
srcHost = (const void *)srcHostPtr;
|
||||
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
CUDA_CHECK(cuMemcpyHtoD(dstDevice, srcHost, byteCount));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
||||
cuMemcpyHtoD(dstDevice, srcHost, byteCount));
|
||||
Py_END_ALLOW_THREADS;
|
||||
|
||||
Py_RETURN_NONE;
|
||||
@@ -321,7 +341,7 @@ static PyObject *memFree(PyObject *self, PyObject *args) {
|
||||
}
|
||||
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
CUDA_CHECK(cuMemFree(dptr));
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuMemFree(dptr));
|
||||
Py_END_ALLOW_THREADS;
|
||||
|
||||
Py_RETURN_NONE;
|
||||
@@ -411,7 +431,7 @@ static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) {
|
||||
}
|
||||
// Call the function
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
CUDA_CHECK(cuTensorMapEncodeTiledHandle(
|
||||
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuTensorMapEncodeTiledHandle(
|
||||
tensorMap, tensorDataType, tensorRank, globalAddress, globalDim,
|
||||
globalStrides, boxDim, elementStrides, interleave, swizzle, l2Promotion,
|
||||
oobFill));
|
||||
|
||||
@@ -19,6 +19,7 @@ def default_dump_dir():
|
||||
|
||||
|
||||
class CacheManager(ABC):
|
||||
|
||||
def __init__(self, key):
|
||||
pass
|
||||
|
||||
@@ -44,20 +45,21 @@ class CacheManager(ABC):
|
||||
|
||||
|
||||
class FileCacheManager(CacheManager):
|
||||
|
||||
def __init__(self, key, override=False, dump=False):
|
||||
self.key = key
|
||||
self.lock_path = None
|
||||
if (dump):
|
||||
if dump:
|
||||
self.cache_dir = default_dump_dir()
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||
os.makedirs(self.cache_dir, exist_ok=True)
|
||||
elif (override):
|
||||
elif override:
|
||||
self.cache_dir = default_override_dir()
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
else:
|
||||
# create cache directory if it doesn't exist
|
||||
self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir()
|
||||
self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
|
||||
if self.cache_dir:
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||
@@ -93,9 +95,8 @@ class FileCacheManager(CacheManager):
|
||||
result = {}
|
||||
for c in child_paths:
|
||||
p = self._make_path(c)
|
||||
if not os.path.exists(p):
|
||||
raise Exception(f"Group file {p} does not exist from group {grp_filename} ")
|
||||
result[c] = p
|
||||
if os.path.exists(p):
|
||||
result[c] = p
|
||||
return result
|
||||
|
||||
# Note a group of pushed files as being part of a group
|
||||
@@ -142,6 +143,7 @@ def get_cache_manager(key) -> CacheManager:
|
||||
|
||||
if user_cache_manager is not None and user_cache_manager != __cache_cls_nme:
|
||||
import importlib
|
||||
|
||||
module_path, clz_nme = user_cache_manager.split(":")
|
||||
module = importlib.import_module(module_path)
|
||||
__cache_cls = getattr(module, clz_nme)
|
||||
|
||||
@@ -9,7 +9,6 @@ from .cache import get_cache_manager
|
||||
|
||||
|
||||
class DriverBase(metaclass=abc.ABCMeta):
|
||||
|
||||
CUDA = 0
|
||||
HIP = 1
|
||||
|
||||
@@ -19,6 +18,8 @@ class DriverBase(metaclass=abc.ABCMeta):
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# CUDA
|
||||
# -----------------------------
|
||||
@@ -27,7 +28,7 @@ class DriverBase(metaclass=abc.ABCMeta):
|
||||
class CudaUtils(object):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(CudaUtils, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
@@ -47,6 +48,7 @@ class CudaUtils(object):
|
||||
with open(so, "rb") as f:
|
||||
cache_path = cache.put(f.read(), fname, binary=True)
|
||||
import importlib.util
|
||||
|
||||
spec = importlib.util.spec_from_file_location("cuda_utils", cache_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
@@ -66,7 +68,7 @@ class CudaUtils(object):
|
||||
class CudaDriver(DriverBase):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(CudaDriver, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
@@ -74,14 +76,16 @@ class CudaDriver(DriverBase):
|
||||
self.utils = CudaUtils()
|
||||
self.backend = self.CUDA
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# HIP
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class HIPUtils(object):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(HIPUtils, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
@@ -101,6 +105,7 @@ class HIPUtils(object):
|
||||
with open(so, "rb") as f:
|
||||
cache_path = cache.put(f.read(), fname, binary=True)
|
||||
import importlib.util
|
||||
|
||||
spec = importlib.util.spec_from_file_location("hip_utils", cache_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
@@ -111,7 +116,7 @@ class HIPUtils(object):
|
||||
class HIPDriver(DriverBase):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(HIPDriver, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
@@ -123,7 +128,7 @@ class HIPDriver(DriverBase):
|
||||
class UnsupportedDriver(DriverBase):
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, 'instance'):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(UnsupportedDriver, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
@@ -131,12 +136,14 @@ class UnsupportedDriver(DriverBase):
|
||||
self.utils = None
|
||||
self.backend = None
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Driver
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class LazyProxy:
|
||||
|
||||
def __init__(self, init_fn):
|
||||
self._init_fn = init_fn
|
||||
self._obj = None
|
||||
@@ -150,7 +157,7 @@ class LazyProxy:
|
||||
return getattr(self._obj, name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name in ['_init_fn', '_obj']:
|
||||
if name in ["_init_fn", "_obj"]:
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
self._initialize_obj()
|
||||
@@ -172,6 +179,7 @@ class LazyProxy:
|
||||
|
||||
def initialize_driver():
|
||||
import torch
|
||||
|
||||
if torch.version.hip is not None:
|
||||
return HIPDriver()
|
||||
elif torch.cuda.is_available():
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
|
||||
class OutOfResources(Exception):
|
||||
|
||||
def __init__(self, required, limit, name):
|
||||
self.message = f'out of resource: {name}, '\
|
||||
f'Required: {required}, '\
|
||||
f'Hardware limit: {limit}'
|
||||
self.message += '. Reducing block sizes or `num_stages` may help.'
|
||||
self.message = f"out of resource: {name}, " f"Required: {required}, " f"Hardware limit: {limit}"
|
||||
self.message += ". Reducing block sizes or `num_stages` may help."
|
||||
self.required = required
|
||||
self.limit = limit
|
||||
self.name = name
|
||||
|
||||
@@ -74,11 +74,15 @@ class BlockPointerHandle:
|
||||
|
||||
|
||||
def wrap_ret(compute_ret_ty):
|
||||
|
||||
def wrapper(fn):
|
||||
|
||||
def wrapped(*args, **kwargs):
|
||||
ret = fn(*args, **kwargs)
|
||||
return TensorHandle(ret.data, compute_ret_ty(*args, **kwargs))
|
||||
|
||||
return wrapped
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@@ -249,11 +253,13 @@ class Builder:
|
||||
# ternary functions
|
||||
def ternary_op(self, lhs, rhs, other, op):
|
||||
return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype)
|
||||
|
||||
create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where)
|
||||
|
||||
# unary functions
|
||||
def unary_op(self, arg, op):
|
||||
return TensorHandle(op(arg.data), arg.dtype)
|
||||
|
||||
create_exp = lambda self, arg: self.unary_op(arg, np.exp)
|
||||
create_cos = lambda self, arg: self.unary_op(arg, np.cos)
|
||||
create_sin = lambda self, arg: self.unary_op(arg, np.sin)
|
||||
@@ -279,7 +285,8 @@ class Builder:
|
||||
dtype_tt = ptr.dtype.element_ty
|
||||
return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype)
|
||||
|
||||
def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile):
|
||||
def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy,
|
||||
is_volatile):
|
||||
ptrs, masks = ptr.materialize_pointers(boundary_check)
|
||||
assert padding_option is None
|
||||
other = None
|
||||
@@ -297,6 +304,7 @@ class Builder:
|
||||
|
||||
def create_int_to_ptr(self, val, dst_ty):
|
||||
return TensorHandle(val.data.astype(np.uint64), dst_ty)
|
||||
|
||||
# def create_cat(self, lhs, rhs):
|
||||
# pass
|
||||
|
||||
@@ -360,7 +368,10 @@ class Builder:
|
||||
|
||||
|
||||
def patch_attr(obj, name, member, builder):
|
||||
new_member = lambda *args, member=member, **kwargs: (member(*args, **{k: v for k, v in kwargs.items() if k != '_builder'}, _builder=builder))
|
||||
new_member = lambda *args, member=member, **kwargs: (member(*args, **
|
||||
{k: v
|
||||
for k, v in kwargs.items()
|
||||
if k != "_builder"}, _builder=builder))
|
||||
setattr(obj, name, new_member)
|
||||
|
||||
|
||||
@@ -384,8 +395,8 @@ def _patch_lang_core(lang, builder):
|
||||
def _new_reduce(input, axis, combine_fn):
|
||||
fn = combine_fn.fn.__name__
|
||||
mapping = {
|
||||
'maximum': np.max,
|
||||
'_sum_combine': np.sum,
|
||||
"maximum": np.max,
|
||||
"_sum_combine": np.sum,
|
||||
}
|
||||
ret = mapping[fn](input.handle.data, axis=axis)
|
||||
ret_type = tl.block_type(input.dtype, ret.shape)
|
||||
@@ -397,15 +408,16 @@ def _patch_lang_core(lang, builder):
|
||||
def _patch_lang_math(lang, builder):
|
||||
math = lang.math
|
||||
mapping = {
|
||||
'abs': 'abs',
|
||||
'acos': 'arccos',
|
||||
'asin': 'arcsin',
|
||||
'exp2': 'exp2',
|
||||
'log2': 'log2',
|
||||
'max': 'maximum',
|
||||
"abs": "abs",
|
||||
"acos": "arccos",
|
||||
"asin": "arcsin",
|
||||
"exp2": "exp2",
|
||||
"log2": "log2",
|
||||
"max": "maximum",
|
||||
}
|
||||
|
||||
def make_numpy(name):
|
||||
|
||||
def impl(*args, **kwargs):
|
||||
ret_type = args[0].type # TODO: incorrect
|
||||
ret_dtype = args[0].dtype # TODO: incorrect
|
||||
@@ -414,15 +426,18 @@ def _patch_lang_math(lang, builder):
|
||||
ret = getattr(np, mapping[name])(*args, **kwargs)
|
||||
ret = tl.core.tensor(TensorHandle(ret, ret_dtype), ret_type)
|
||||
return ret
|
||||
|
||||
return impl
|
||||
|
||||
def make_fallback(name):
|
||||
|
||||
def fallback(*args, **kwargs):
|
||||
raise NotImplementedError(f"""
|
||||
{name} not supported in interpreter mode: no known numpy implementation.
|
||||
If you think that {name} in fact does have a numpy implementation, please add it
|
||||
to the mapping in python/triton/interpreter/new_interpreter.py:_patch_lang_math.
|
||||
""")
|
||||
|
||||
return fallback
|
||||
|
||||
for name, member in inspect.getmembers(math):
|
||||
@@ -438,7 +453,7 @@ def _implicit_cvt(arg):
|
||||
ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
|
||||
handle = TensorHandle(np.array([arg], dtype=np.int32), ty)
|
||||
return tl.tensor(handle, ty)
|
||||
if hasattr(arg, 'data_ptr'):
|
||||
if hasattr(arg, "data_ptr"):
|
||||
ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
|
||||
handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
|
||||
return tl.tensor(handle, ty)
|
||||
@@ -453,28 +468,29 @@ def _unwrap(tensor):
|
||||
|
||||
builder = Builder()
|
||||
|
||||
RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization', 'enable_fp_fusion']
|
||||
RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_warp_specialization", "enable_fp_fusion"]
|
||||
|
||||
|
||||
class GridExecutor:
|
||||
|
||||
def __init__(self, fn, arg_names, grid):
|
||||
from .jit import _normalize_ty # TODO: modularize
|
||||
|
||||
self.fn = fn
|
||||
self.arg_names = arg_names
|
||||
self.grid = grid
|
||||
__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
|
||||
self.constexprs = [name for name in arg_names if __annotations__.get(name) == 'constexpr']
|
||||
self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"]
|
||||
|
||||
def _patch_lang(self, builder):
|
||||
lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
|
||||
assert len(lang) == 1, "triton.language must be visible from within jit'd function"
|
||||
_patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
|
||||
_patch_lang_tensor(getattr(lang[0], "tensor"), builder)
|
||||
_patch_lang_core(lang[0], builder)
|
||||
_patch_lang_math(lang[0], builder)
|
||||
|
||||
def __call__(self, *args_dev, **kwargs):
|
||||
args_hst = [_unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev]
|
||||
args_hst = [_unwrap(arg).cpu() if hasattr(arg, "data_ptr") else arg for arg in args_dev]
|
||||
# removes reserved keywords from kwargs
|
||||
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}
|
||||
# remaps core language functions to interpreted ones
|
||||
@@ -486,7 +502,7 @@ class GridExecutor:
|
||||
# iterate through grid
|
||||
grid = self.grid(args) if callable(self.grid) else self.grid
|
||||
assert len(grid) <= 3
|
||||
grid = grid + (1,) * (3 - len(grid))
|
||||
grid = grid + (1, ) * (3 - len(grid))
|
||||
builder.set_grid_dim(*grid)
|
||||
for x in range(grid[0]):
|
||||
for y in range(grid[1]):
|
||||
@@ -495,7 +511,7 @@ class GridExecutor:
|
||||
self.fn(**args)
|
||||
# copy arguments back to propagate side-effects
|
||||
for arg_dev, arg_hst in zip(args_dev, args_hst):
|
||||
if hasattr(arg_dev, 'data_ptr'):
|
||||
if hasattr(arg_dev, "data_ptr"):
|
||||
_unwrap(arg_dev).copy_(arg_hst.to(arg_dev.device))
|
||||
|
||||
|
||||
@@ -504,17 +520,18 @@ class InterpretedFunction:
|
||||
def _patch_lang(self, builder):
|
||||
lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
|
||||
assert len(lang) == 1, "triton.language must be visible from within jit'd function"
|
||||
_patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
|
||||
_patch_lang_tensor(getattr(lang[0], "tensor"), builder)
|
||||
_patch_lang_core(lang[0], builder)
|
||||
|
||||
def __init__(self, fn) -> None:
|
||||
self.fn = fn
|
||||
|
||||
def run(*args, **kwargs):
|
||||
grid = kwargs['grid']
|
||||
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ['grid']}
|
||||
grid = kwargs["grid"]
|
||||
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ["grid"]}
|
||||
|
||||
return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs)
|
||||
|
||||
self.run = run
|
||||
signature = inspect.signature(fn)
|
||||
self.arg_names = [v.name for v in signature.parameters.values()]
|
||||
|
||||
@@ -5,48 +5,48 @@ import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import subprocess
|
||||
import textwrap
|
||||
from collections import defaultdict, namedtuple
|
||||
from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast,
|
||||
overload)
|
||||
from functools import cached_property
|
||||
from typing import Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast, overload
|
||||
|
||||
from .._C.libtriton.triton import TMAInfos
|
||||
from ..common.backend import get_backend, path_to_ptxas
|
||||
from ..language.core import dtype
|
||||
from ..common.backend import get_backend, get_cuda_version_key
|
||||
from .interpreter import InterpretedFunction
|
||||
|
||||
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
TRITON_VERSION = "2.1.0"
|
||||
|
||||
|
||||
def get_cuda_stream(idx=None):
|
||||
if idx is None:
|
||||
idx = get_current_device()
|
||||
try:
|
||||
from torch._C import _cuda_getCurrentRawStream
|
||||
|
||||
return _cuda_getCurrentRawStream(idx)
|
||||
except ImportError:
|
||||
import torch
|
||||
|
||||
return torch.cuda.current_stream(idx).cuda_stream
|
||||
|
||||
|
||||
def get_current_device():
|
||||
import torch
|
||||
|
||||
return torch.cuda.current_device()
|
||||
|
||||
|
||||
def set_current_device(idx):
|
||||
import torch
|
||||
|
||||
torch.cuda.set_device(idx)
|
||||
|
||||
|
||||
def get_device_capability(idx):
|
||||
import torch
|
||||
|
||||
return torch.cuda.get_device_capability(idx)
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Dependencies Finder
|
||||
@@ -72,7 +72,8 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
lhs = self.visit(node.value)
|
||||
while isinstance(lhs, ast.Attribute):
|
||||
lhs = self.visit(lhs.value)
|
||||
if lhs is None or (getattr(lhs, "__name__", "") == "triton" or getattr(lhs, "__name__", "").endswith(".triton")):
|
||||
if lhs is None or (getattr(lhs, "__name__", "") == "triton"
|
||||
or getattr(lhs, "__name__", "").endswith(".triton")):
|
||||
return None
|
||||
return getattr(lhs, node.attr)
|
||||
|
||||
@@ -82,55 +83,26 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
return
|
||||
if inspect.isbuiltin(func):
|
||||
return
|
||||
if func.__module__ and (func.__module__.startswith('triton.') or '.triton.' in func.__module__):
|
||||
if func.__module__ and (func.__module__.startswith("triton.") or ".triton." in func.__module__):
|
||||
return
|
||||
assert isinstance(func, JITFunction), f"Function \"{func.__name__}\" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this"
|
||||
assert isinstance(
|
||||
func, JITFunction
|
||||
), f'Function "{func.__name__}" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this'
|
||||
if func.hash is None:
|
||||
tree = ast.parse(func.src)
|
||||
finder = DependenciesFinder(func.__globals__, func.src)
|
||||
finder.visit(tree)
|
||||
func.hash = finder.ret
|
||||
noinline = str(getattr(func, 'noinline', False))
|
||||
noinline = str(getattr(func, "noinline", False))
|
||||
self.ret = (self.ret + func.hash + noinline).encode("utf-8")
|
||||
self.ret = hashlib.sha1(self.ret).hexdigest()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# JITFunction
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def version_key():
|
||||
import pkgutil
|
||||
contents = []
|
||||
# frontend
|
||||
with open(__file__, "rb") as f:
|
||||
contents += [hashlib.sha1(f.read()).hexdigest()]
|
||||
# compiler
|
||||
compiler_path = os.path.join(TRITON_PATH, 'compiler')
|
||||
for lib in pkgutil.iter_modules([compiler_path]):
|
||||
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
||||
contents += [hashlib.sha1(f.read()).hexdigest()]
|
||||
# backend
|
||||
libtriton_hash = hashlib.sha1()
|
||||
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
|
||||
while True:
|
||||
chunk = f.read(1024 ** 2)
|
||||
if not chunk:
|
||||
break
|
||||
libtriton_hash.update(chunk)
|
||||
contents.append(libtriton_hash.hexdigest())
|
||||
# language
|
||||
language_path = os.path.join(TRITON_PATH, 'language')
|
||||
for lib in pkgutil.iter_modules([language_path]):
|
||||
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
||||
contents += [hashlib.sha1(f.read()).hexdigest()]
|
||||
# ptxas version
|
||||
ptxas = path_to_ptxas()[0]
|
||||
ptxas_version = hashlib.sha1(subprocess.check_output([ptxas, "--version"])).hexdigest()
|
||||
return '-'.join(TRITON_VERSION) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||
|
||||
|
||||
def _normalize_ty(ty) -> str:
|
||||
if isinstance(ty, type):
|
||||
return ty.__name__
|
||||
@@ -139,6 +111,85 @@ def _normalize_ty(ty) -> str:
|
||||
return repr(ty)
|
||||
|
||||
|
||||
class KernelParam:
|
||||
"""Represents a parameter to a @jit'ed function.
|
||||
|
||||
A parameter is just the name plus metadata; a parameter plus a value is a
|
||||
KernelArg.
|
||||
"""
|
||||
|
||||
def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool):
|
||||
self.num = num
|
||||
self._param = param
|
||||
self.do_not_specialize = do_not_specialize
|
||||
|
||||
@cached_property
|
||||
def name(self):
|
||||
return self._param.name
|
||||
|
||||
@cached_property
|
||||
def annotation(self):
|
||||
if not self._param.annotation or self._param.annotation == inspect.Parameter.empty:
|
||||
return ""
|
||||
return _normalize_ty(self._param.annotation)
|
||||
|
||||
@cached_property
|
||||
def is_constexpr(self):
|
||||
return "constexpr" in self.annotation
|
||||
|
||||
@property
|
||||
def default(self):
|
||||
return self._param.default
|
||||
|
||||
@property
|
||||
def has_default(self):
|
||||
return self._param.default != inspect.Parameter.empty
|
||||
|
||||
|
||||
class KernelArg:
|
||||
"""Represents an argument to a @jit'ed function.
|
||||
|
||||
An argument is a parameter plus a value.
|
||||
"""
|
||||
|
||||
def __init__(self, value, param):
|
||||
self.value = value
|
||||
self.param = param
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.param.name
|
||||
|
||||
def signature_key(self):
|
||||
annotation = self.param.annotation
|
||||
if "Tensor" in annotation:
|
||||
return self.value.dtype
|
||||
elif annotation == "bool":
|
||||
return "i1"
|
||||
elif annotation == "float":
|
||||
return "fp32"
|
||||
else:
|
||||
return JITFunction._key_of(self.value)
|
||||
|
||||
def specialization_key(self):
|
||||
assert not self.param.do_not_specialize
|
||||
|
||||
try:
|
||||
return (self.value.data_ptr() % JITFunction.divisibility == 0, )
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
if isinstance(self.value, int):
|
||||
# bool is a subclass of int, so we don't check explicitly above.
|
||||
return (
|
||||
self.value % JITFunction.divisibility == 0,
|
||||
self.value % JITFunction.divisibility_8 == 0,
|
||||
self.value == 1,
|
||||
)
|
||||
|
||||
return (False, )
|
||||
|
||||
|
||||
class KernelInterface(Generic[T]):
|
||||
run: T
|
||||
|
||||
@@ -152,7 +203,6 @@ class KernelInterface(Generic[T]):
|
||||
|
||||
|
||||
class JITFunction(KernelInterface[T]):
|
||||
|
||||
# Hook for inspecting compiled functions and modules
|
||||
cache_hook = None
|
||||
divisibility = 16
|
||||
@@ -169,44 +219,44 @@ class JITFunction(KernelInterface[T]):
|
||||
elif isinstance(arg, bool):
|
||||
return "i1"
|
||||
elif isinstance(arg, int):
|
||||
if -2**31 <= arg and arg <= 2**31 - 1:
|
||||
if -(2**31) <= arg and arg <= 2**31 - 1:
|
||||
return "i32"
|
||||
elif 2**63 <= arg and arg <= 2**64 - 1:
|
||||
return "u64"
|
||||
else:
|
||||
return "i64"
|
||||
elif isinstance(arg, float):
|
||||
return 'fp32'
|
||||
return "fp32"
|
||||
elif arg is None:
|
||||
return None
|
||||
else:
|
||||
raise TypeError(f'Unsupported type {type(arg)} for {arg}')
|
||||
raise TypeError(f"Unsupported type {type(arg)} for {arg}")
|
||||
|
||||
@staticmethod
|
||||
def _device_of(arg):
|
||||
if hasattr(arg, "device"):
|
||||
if hasattr(arg.device, 'type'):
|
||||
return arg.device.type
|
||||
|
||||
return ''
|
||||
try:
|
||||
return arg.device.type
|
||||
except AttributeError:
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _pinned_memory_of(arg):
|
||||
if hasattr(arg, "is_pinned"):
|
||||
if isinstance(arg.is_pinned, Callable):
|
||||
return arg.is_pinned()
|
||||
|
||||
return False
|
||||
try:
|
||||
return arg.is_pinned()
|
||||
except (AttributeError, TypeError):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _spec_of(arg):
|
||||
if hasattr(arg, "data_ptr"):
|
||||
return (arg.data_ptr() % JITFunction.divisibility == 0)
|
||||
return arg.data_ptr() % JITFunction.divisibility == 0
|
||||
elif isinstance(arg, int):
|
||||
return (arg % 16 == 0, arg == 1)
|
||||
return (arg is None, )
|
||||
|
||||
# TODO(jlebar): Fold this into the KernelArg class.
|
||||
def _get_config(self, *args):
|
||||
|
||||
def is_divisible_by_16(x):
|
||||
if hasattr(x, "data_ptr"):
|
||||
return x.data_ptr() % JITFunction.divisibility == 0
|
||||
@@ -222,28 +272,38 @@ class JITFunction(KernelInterface[T]):
|
||||
if x is None:
|
||||
return True
|
||||
return False
|
||||
divisible_by_16 = {i for i, arg in enumerate(
|
||||
args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
|
||||
divisible_by_8 = {i for i, arg in enumerate(
|
||||
args) if is_divisible_by_8(arg) and i not in self.do_not_specialize}
|
||||
|
||||
divisible_by_16 = {
|
||||
param.num
|
||||
for param, arg in zip(self.params, args)
|
||||
if is_divisible_by_16(arg) and not param.do_not_specialize
|
||||
}
|
||||
divisible_by_8 = {
|
||||
param.num
|
||||
for param, arg in zip(self.params, args)
|
||||
if is_divisible_by_8(arg) and not param.do_not_specialize
|
||||
}
|
||||
equal_to_1 = {
|
||||
i for i, arg in enumerate(args) if isinstance(
|
||||
arg, int) and not isinstance(
|
||||
arg, bool) and arg == 1 and i not in self.do_not_specialize}
|
||||
param.num
|
||||
for param, arg in zip(self.params, args)
|
||||
if isinstance(arg, int) and not isinstance(arg, bool) and arg == 1 and not param.do_not_specialize
|
||||
}
|
||||
# folded equal_to_1 and None
|
||||
# TODO: method to collect all folded args
|
||||
none_args = {i for i, arg in enumerate(args) if arg is None and i not in self.do_not_specialize}
|
||||
none_args = {param.num for param, arg in zip(self.params, args) if arg is None and not param.do_not_specialize}
|
||||
ids_of_folded_args = equal_to_1 | none_args
|
||||
return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])(
|
||||
tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), tuple(divisible_by_8))
|
||||
return namedtuple("instance_descriptor",
|
||||
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])( #
|
||||
tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args),
|
||||
tuple(divisible_by_8))
|
||||
# return _triton.code_gen.instance_descriptor(divisible_by_16,
|
||||
# equal_to_1)
|
||||
|
||||
@staticmethod
|
||||
def _type_of(key):
|
||||
# None are nullptr -- implicitly converted to *i8
|
||||
# `None` is nullptr. Implicitly convert to *i8.
|
||||
if key is None:
|
||||
return '*i8'
|
||||
return "*i8"
|
||||
dtype_str = str(key).split(".")[-1]
|
||||
tys = {
|
||||
"bool": "i1",
|
||||
@@ -281,21 +341,46 @@ class JITFunction(KernelInterface[T]):
|
||||
constants = dict(zip(self.constexprs, constexpr_key))
|
||||
return constants
|
||||
|
||||
<<<<<<< HEAD
|
||||
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization,enable_fp_fusion, extern_libs, configs):
|
||||
=======
|
||||
def _call_hook(
|
||||
self,
|
||||
key,
|
||||
signature,
|
||||
device,
|
||||
constants,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
num_stages,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
extern_libs,
|
||||
configs,
|
||||
):
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
if JITFunction.cache_hook is None:
|
||||
return False
|
||||
|
||||
name = self.fn.__name__
|
||||
module = self.fn.__module__
|
||||
<<<<<<< HEAD
|
||||
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs}), enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
|
||||
=======
|
||||
arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])])
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}, enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
key = str(key)
|
||||
|
||||
class LegacyCompiler:
|
||||
|
||||
def __init__(self, module, name):
|
||||
self.module = module
|
||||
self.name = name
|
||||
pass
|
||||
|
||||
<<<<<<< HEAD
|
||||
kwargs = dict(signature=signature, device=device, constants=constants,
|
||||
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs,
|
||||
configs=configs)
|
||||
@@ -326,18 +411,43 @@ class JITFunction(KernelInterface[T]):
|
||||
return 'fp32'
|
||||
else:
|
||||
return self._key_of(arg)
|
||||
=======
|
||||
kwargs = dict(
|
||||
signature=signature,
|
||||
device=device,
|
||||
constants=constants,
|
||||
num_warps=num_warps,
|
||||
num_ctas=num_ctas,
|
||||
num_stages=num_stages,
|
||||
enable_warp_specialization=enable_warp_specialization,
|
||||
enable_fp_fusion=enable_fp_fusion,
|
||||
extern_libs=extern_libs,
|
||||
configs=configs,
|
||||
)
|
||||
|
||||
return JITFunction.cache_hook(
|
||||
key=key,
|
||||
repr=repr,
|
||||
fn=LegacyCompiler(module, name),
|
||||
compile={"key": key, **kwargs},
|
||||
is_manual_warmup=False,
|
||||
already_compiled=False,
|
||||
)
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
|
||||
def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: List[bool]) -> str:
|
||||
device_types = [device_type for device_type in device_types if device_type != '']
|
||||
device_types = [device_type for device_type in device_types if device_type != ""]
|
||||
# Return cuda if one of the input tensors is cuda
|
||||
if 'cuda' in device_types:
|
||||
if "cuda" in device_types:
|
||||
import torch
|
||||
return 'hip' if torch.version.hip else 'cuda'
|
||||
|
||||
is_cpu = all(device_type == 'cpu' for device_type in device_types)
|
||||
return "hip" if torch.version.hip else "cuda"
|
||||
|
||||
is_cpu = all(device_type == "cpu" for device_type in device_types)
|
||||
is_pinned_memory = any(pinned_memory_flag for pinned_memory_flag in pinned_memory_flags)
|
||||
# Return cuda if all the input tensors are cpu while the memory is pinned
|
||||
if is_cpu and is_pinned_memory:
|
||||
<<<<<<< HEAD
|
||||
return 'cuda'
|
||||
|
||||
return device_types[0] if len(device_types) > 0 else 'cuda'
|
||||
@@ -452,16 +562,193 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
scope = {"launcher_body": launcher_body}
|
||||
exec(src, scope)
|
||||
return scope[self.fn.__name__]
|
||||
=======
|
||||
return "cuda"
|
||||
|
||||
return device_types[0] if len(device_types) > 0 else "cuda"
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
from ..compiler import CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps
|
||||
|
||||
# Get a compiler-flags arg like `num_warps` and remove it from kwargs.
|
||||
def get_special_arg(name: str, default=None):
|
||||
if name not in kwargs:
|
||||
return default
|
||||
ret = kwargs[name]
|
||||
del kwargs[name]
|
||||
return ret
|
||||
|
||||
grid = get_special_arg("grid")
|
||||
num_warps = get_special_arg("num_warps")
|
||||
num_ctas = get_special_arg("num_ctas", 1)
|
||||
num_stages = get_special_arg("num_stages")
|
||||
enable_warp_specialization = get_special_arg("enable_warp_specialization", False)
|
||||
enable_fp_fusion = get_special_arg("enable_fp_fusion", True)
|
||||
extern_libs = get_special_arg("extern_libs")
|
||||
stream = get_special_arg("stream")
|
||||
warmup = get_special_arg("warmup", False)
|
||||
device = get_special_arg("device")
|
||||
device_type = get_special_arg("device_type")
|
||||
|
||||
# Bind the remaining arguments to `fn`.
|
||||
bound_args = self.signature.bind(*args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
|
||||
assert len(bound_args.arguments) == len(self.params)
|
||||
args = [KernelArg(arg_value, param) for (_, arg_value), param in zip(bound_args.arguments.items(), self.params)]
|
||||
|
||||
non_constexpr_arg_values = [arg.value for arg in args if not arg.param.is_constexpr]
|
||||
|
||||
sig_key = tuple(arg.signature_key() for arg in args if not arg.param.is_constexpr)
|
||||
spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize)
|
||||
constexpr_key = tuple(arg.value for arg in args if arg.param.is_constexpr)
|
||||
|
||||
assert num_ctas > 0
|
||||
assert grid is not None
|
||||
if callable(grid):
|
||||
# Arguments are passed as a dict to `grid`, by contract.
|
||||
# TODO(jlebar): In the new launch API, pass the compiler flags as a
|
||||
# second parameter to `grid`.
|
||||
grid = grid(dict(bound_args.arguments))
|
||||
grid_size = len(grid)
|
||||
grid_0 = grid[0]
|
||||
grid_1 = grid[1] if grid_size > 1 else 1
|
||||
grid_2 = grid[2] if grid_size > 2 else 1
|
||||
if device_type is None:
|
||||
device_types = [self._device_of(arg) for arg in non_constexpr_arg_values]
|
||||
device_types = [_device_type for _device_type in device_types if _device_type != ""]
|
||||
device_type = self._conclude_device_type(device_types,
|
||||
[self._pinned_memory_of(arg) for arg in non_constexpr_arg_values])
|
||||
|
||||
device_backend = None
|
||||
if device_type not in ["cuda"]:
|
||||
device_backend = get_backend(device_type)
|
||||
if device_backend is None:
|
||||
raise ValueError("Cannot find backend for " + device_type)
|
||||
|
||||
if device is None:
|
||||
if device_type in ["cuda"]:
|
||||
device = get_current_device()
|
||||
set_current_device(device)
|
||||
else:
|
||||
device = device_backend.get_current_device()
|
||||
device_backend.set_current_device(device)
|
||||
if stream is None and not warmup:
|
||||
if device_type in ["cuda"]:
|
||||
stream = get_cuda_stream(device)
|
||||
else:
|
||||
stream = device_backend.get_stream()
|
||||
|
||||
if num_warps is None:
|
||||
num_warps = get_arch_default_num_warps(device_type)
|
||||
if num_stages is None:
|
||||
num_stages = get_arch_default_num_stages(device_type)
|
||||
|
||||
if device_type in ["cuda"]:
|
||||
version_key = get_cuda_version_key()
|
||||
else:
|
||||
version_key = device_backend.get_version_key()
|
||||
key = (
|
||||
version_key,
|
||||
sig_key,
|
||||
constexpr_key,
|
||||
spec_key,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
num_stages,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
self.debug,
|
||||
)
|
||||
if extern_libs is not None:
|
||||
key = (key, tuple(extern_libs.items()))
|
||||
|
||||
# Kernel is not cached; we have to compile.
|
||||
if key not in self.cache[device]:
|
||||
configs = (self._get_config(*[arg.value for arg in args]), )
|
||||
constants = {
|
||||
arg.param.num: arg.value
|
||||
for arg in args
|
||||
if arg.param.is_constexpr or arg.param.num in configs[0].equal_to_1 or arg.value is None
|
||||
}
|
||||
for i, arg in constants.items():
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {i} is not supported")
|
||||
|
||||
# Build kernel signature -- doesn't include constexpr arguments.
|
||||
signature = {
|
||||
arg.param.num: self._type_of(self._key_of(arg.value))
|
||||
for arg in args
|
||||
if not arg.param.is_constexpr
|
||||
}
|
||||
|
||||
if self._call_hook(
|
||||
key,
|
||||
signature,
|
||||
device,
|
||||
constants,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
num_stages,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
extern_libs,
|
||||
configs,
|
||||
):
|
||||
return None
|
||||
|
||||
self.cache[device][key] = compile(
|
||||
self,
|
||||
signature=signature,
|
||||
device=device,
|
||||
constants=constants,
|
||||
num_warps=num_warps,
|
||||
num_ctas=num_ctas,
|
||||
num_stages=num_stages,
|
||||
enable_warp_specialization=enable_warp_specialization,
|
||||
enable_fp_fusion=enable_fp_fusion,
|
||||
extern_libs=extern_libs,
|
||||
configs=configs,
|
||||
debug=self.debug,
|
||||
device_type=device_type,
|
||||
)
|
||||
|
||||
bin = self.cache[device][key]
|
||||
if not warmup:
|
||||
bin.c_wrapper(
|
||||
grid_0,
|
||||
grid_1,
|
||||
grid_2,
|
||||
bin.num_warps,
|
||||
bin.num_ctas,
|
||||
bin.clusterDims[0],
|
||||
bin.clusterDims[1],
|
||||
bin.clusterDims[2],
|
||||
bin.shared,
|
||||
stream,
|
||||
bin.cu_function,
|
||||
CompiledKernel.launch_enter_hook,
|
||||
CompiledKernel.launch_exit_hook,
|
||||
bin,
|
||||
*bin.assemble_tensormap_to_arg(non_constexpr_arg_values),
|
||||
)
|
||||
return bin
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
|
||||
def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None):
|
||||
do_not_specialize = do_not_specialize if do_not_specialize else []
|
||||
|
||||
self.fn = fn
|
||||
self.module = fn.__module__
|
||||
self.version = version
|
||||
# function signature information
|
||||
signature = inspect.signature(fn)
|
||||
self.arg_names = [v.name for v in signature.parameters.values()]
|
||||
self.arg_defaults = [v.default for v in signature.parameters.values()]
|
||||
self.has_defaults = any(v != inspect._empty for v in self.arg_defaults)
|
||||
self.signature = inspect.signature(fn)
|
||||
self.do_not_specialize = do_not_specialize
|
||||
|
||||
self.params = []
|
||||
for i, param in enumerate(self.signature.parameters.values()):
|
||||
dns = do_not_specialize and (i in do_not_specialize or param.name in do_not_specialize)
|
||||
self.params.append(KernelParam(i, param, dns))
|
||||
|
||||
# function source code (without decorators)
|
||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||
self.src = self.src[self.src.find("def"):]
|
||||
@@ -470,22 +757,18 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
self.hash = None
|
||||
# JITFunction can be instantiated as kernel
|
||||
# when called with a grid using __getitem__
|
||||
self.kernel_decorators = []
|
||||
self.kernel = None
|
||||
self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug
|
||||
self.noinline = noinline
|
||||
# annotations
|
||||
self.__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
|
||||
# index of constexprs
|
||||
self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
|
||||
# specialization hints
|
||||
regular_args = [arg for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
|
||||
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
|
||||
self.do_not_specialize = {regular_args.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
|
||||
|
||||
# tma info
|
||||
self.tensormaps_info = TMAInfos()
|
||||
# launcher
|
||||
self.run = self._make_launcher()
|
||||
|
||||
# TODO(jlebar): Remove uses of these fields outside this file, then
|
||||
# remove the fields here.
|
||||
self.arg_names = [p.name for p in self.params]
|
||||
self.constexprs = [p.num for p in self.params if p.is_constexpr]
|
||||
|
||||
# re-use docs of wrapped function
|
||||
self.__doc__ = fn.__doc__
|
||||
self.__name__ = fn.__name__
|
||||
@@ -498,7 +781,7 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
if self.hash is None:
|
||||
dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
|
||||
dependencies_finder.visit(self.parse())
|
||||
self.hash = dependencies_finder.ret + version_key()
|
||||
self.hash = dependencies_finder.ret
|
||||
return self.hash
|
||||
|
||||
def warmup(self, *args, **kwargs):
|
||||
@@ -518,14 +801,10 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
# - when kernel decorators change, cached kernel
|
||||
# needs to be cleared
|
||||
if name == 'kernel_decorators':
|
||||
self.kernel = None
|
||||
super(JITFunction, self).__setattr__(name, value)
|
||||
# - when `.src` attribute is set, cache path needs
|
||||
# to be reinitialized
|
||||
if name == 'src':
|
||||
if name == "src":
|
||||
self.hash = None
|
||||
|
||||
def __repr__(self):
|
||||
@@ -591,12 +870,14 @@ def jit(
|
||||
debug=debug,
|
||||
noinline=noinline,
|
||||
)
|
||||
|
||||
if fn is not None:
|
||||
return decorator(fn)
|
||||
|
||||
else:
|
||||
return decorator
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utilities for mocking tensors
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -607,10 +888,10 @@ class MockTensor:
|
||||
Can be used in place of real tensors when calling:
|
||||
kernel.warmup(MockTensor(torch.float32), ...)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def wrap_dtype(arg):
|
||||
if arg.__class__.__name__ == "dtype" and\
|
||||
arg.__module__ == "torch":
|
||||
if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch":
|
||||
return MockTensor(arg)
|
||||
return arg
|
||||
|
||||
@@ -623,6 +904,7 @@ class MockTensor:
|
||||
|
||||
|
||||
class TensorWrapper:
|
||||
|
||||
def __init__(self, base, dtype):
|
||||
self.dtype = dtype
|
||||
self.base = base
|
||||
@@ -637,7 +919,7 @@ class TensorWrapper:
|
||||
return self.base.stride(i)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'TensorWrapper[{self.dtype}]({self.base})'
|
||||
return f"TensorWrapper[{self.dtype}]({self.base})"
|
||||
|
||||
def element_size(self):
|
||||
return self.base.element_size()
|
||||
@@ -655,4 +937,4 @@ def reinterpret(tensor, dtype):
|
||||
# A new wrapper is needed around an unwrapped tensor.
|
||||
return TensorWrapper(tensor, dtype)
|
||||
else:
|
||||
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
|
||||
raise TypeError(f"Cannot reinterpret a {type(tensor)}.")
|
||||
|
||||
@@ -78,10 +78,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None):
|
||||
return torch.mean(torch.tensor(ret)).item()
|
||||
|
||||
|
||||
def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
|
||||
quantiles=None,
|
||||
fast_flush=True,
|
||||
return_mode="mean"):
|
||||
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"):
|
||||
assert return_mode in ["min", "max", "mean", "median"]
|
||||
import torch
|
||||
"""
|
||||
@@ -261,11 +258,12 @@ class Benchmark:
|
||||
|
||||
|
||||
class Mark:
|
||||
|
||||
def __init__(self, fn, benchmarks):
|
||||
self.fn = fn
|
||||
self.benchmarks = benchmarks
|
||||
|
||||
def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, **kwrags):
|
||||
def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False, **kwrags):
|
||||
import os
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -321,24 +319,36 @@ class Mark:
|
||||
if save_path:
|
||||
plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
|
||||
df = df[x_names + bench.line_names]
|
||||
if diff_col and df.shape[1] == 2:
|
||||
col0, col1 = df.columns.tolist()
|
||||
df['Diff'] = df[col1] - df[col0]
|
||||
|
||||
if print_data:
|
||||
print(bench.plot_name + ':')
|
||||
print(df)
|
||||
if save_path:
|
||||
df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)
|
||||
return df
|
||||
|
||||
def run(self, show_plots=False, print_data=False, save_path='', **kwargs):
|
||||
def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs):
|
||||
has_single_bench = isinstance(self.benchmarks, Benchmark)
|
||||
benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
|
||||
result_dfs = []
|
||||
if save_path:
|
||||
html = open(os.path.join(save_path, "results.html"), "w")
|
||||
html.write("<html><body>\n")
|
||||
for bench in benchmarks:
|
||||
self._run(bench, save_path, show_plots, print_data, **kwargs)
|
||||
result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
|
||||
if save_path:
|
||||
html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
|
||||
if save_path:
|
||||
html.write("</body></html>\n")
|
||||
if return_df:
|
||||
if has_single_bench:
|
||||
return result_dfs[0]
|
||||
else:
|
||||
return result_dfs
|
||||
return None
|
||||
|
||||
|
||||
def perf_report(benchmarks):
|
||||
@@ -393,12 +403,15 @@ def get_max_tensorcore_tflops(dtype, clock_rate, backend=None, device=None):
|
||||
tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
|
||||
return tflops
|
||||
|
||||
|
||||
# create decorator that wraps test function into
|
||||
# a cuda-memcheck system call
|
||||
|
||||
|
||||
def cuda_memcheck(**target_kwargs):
|
||||
|
||||
def decorator(test_fn):
|
||||
|
||||
@functools.wraps(test_fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
import psutil
|
||||
@@ -416,7 +429,9 @@ def cuda_memcheck(**target_kwargs):
|
||||
assert "ERROR SUMMARY: 0 errors" in str(out.stdout)
|
||||
else:
|
||||
test_fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@@ -424,22 +439,18 @@ def cuda_memcheck(**target_kwargs):
|
||||
def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
|
||||
try:
|
||||
subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"])
|
||||
subprocess.check_output(
|
||||
[
|
||||
"nvidia-smi",
|
||||
"-i",
|
||||
"0",
|
||||
f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
|
||||
]
|
||||
)
|
||||
subprocess.check_output(
|
||||
[
|
||||
"nvidia-smi",
|
||||
"-i",
|
||||
"0",
|
||||
f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
|
||||
]
|
||||
)
|
||||
subprocess.check_output([
|
||||
"nvidia-smi",
|
||||
"-i",
|
||||
"0",
|
||||
f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
|
||||
])
|
||||
subprocess.check_output([
|
||||
"nvidia-smi",
|
||||
"-i",
|
||||
"0",
|
||||
f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
|
||||
])
|
||||
cur_sm_clock = nvsmi(["clocks.current.sm"])[0]
|
||||
cur_mem_clock = nvsmi(["clocks.current.memory"])[0]
|
||||
assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz"
|
||||
|
||||
@@ -141,8 +141,7 @@ class ExternLibrary(ABC):
|
||||
f.write(file_str)
|
||||
f.close()
|
||||
if self._format:
|
||||
subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file],
|
||||
stdout=subprocess.PIPE).communicate()
|
||||
subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], stdout=subprocess.PIPE).communicate()
|
||||
subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate()
|
||||
|
||||
|
||||
@@ -208,56 +207,36 @@ class Libdevice(ExternLibrary):
|
||||
|
||||
# Group functions together by renaming.
|
||||
renaming = {
|
||||
'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh',
|
||||
'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': 'add_rn',
|
||||
'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru',
|
||||
'dadd_rz': 'add_rz', 'fadd_rz': 'add_rz', 'asinf': 'asin',
|
||||
'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2',
|
||||
'atanhf': 'atanh', 'brevll': 'brev', 'cbrtf': 'cbrt',
|
||||
'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign',
|
||||
'cosf': 'cos', 'coshf': 'cosh', 'cospif': 'cospi',
|
||||
'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
|
||||
'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn',
|
||||
'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', 'ddiv_ru': 'div_ru',
|
||||
'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf',
|
||||
'erfcf': 'erfc', 'erfcinvf': 'erfcinv', 'erfcxf': 'erfcx',
|
||||
'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10',
|
||||
'exp2f': 'exp2', 'expm1f': 'expm1', 'fabsf': 'abs',
|
||||
'fabs': 'abs', 'fast_fdividef': 'fast_dividef',
|
||||
'fdimf': 'fdim', 'ffsll': 'ffs', 'floorf': 'floor',
|
||||
'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn',
|
||||
'fmaf_ru': 'fma_ru', 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod',
|
||||
'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb',
|
||||
'isinff': 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan',
|
||||
'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
|
||||
'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint',
|
||||
'llroundf': 'llround', 'logf': 'log', 'log10f': 'log10',
|
||||
'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb',
|
||||
'umax': 'max', 'llmax': 'max', 'ullmax': 'max', 'fmaxf': 'max',
|
||||
'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min',
|
||||
'fminf': 'min', 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd',
|
||||
'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', 'dmul_ru': 'mul_ru',
|
||||
'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz',
|
||||
'umul24': 'mul24', 'umulhi': 'mulhi', 'mul64hi': 'mulhi',
|
||||
'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': 'nextafter',
|
||||
'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf',
|
||||
'normcdfinvf': 'normcdfinv', 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow',
|
||||
'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', 'drcp_rd': 'rcp_rd',
|
||||
'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru',
|
||||
'drcp_ru': 'rcp_ru', 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz',
|
||||
'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
|
||||
'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d',
|
||||
'roundf': 'round', 'rsqrtf': 'rsqrt', 'frsqrt_rn': 'rsqrt_rn',
|
||||
'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit',
|
||||
'signbitd': 'signbit', 'sinf': 'sin', 'sinhf': 'sinh',
|
||||
'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd',
|
||||
'dsqrt_rd': 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn',
|
||||
'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', 'fsqrt_rz': 'sqrt_rz',
|
||||
'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd',
|
||||
'fsub_rn': 'sub_rn', 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru',
|
||||
'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
|
||||
'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc',
|
||||
'y0f': 'y0', 'y1f': 'y1', 'ynf': 'yn'
|
||||
'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn':
|
||||
'add_rn', 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', 'dadd_rz': 'add_rz', 'fadd_rz':
|
||||
'add_rz', 'asinf': 'asin', 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', 'atanhf': 'atanh',
|
||||
'brevll': 'brev', 'cbrtf': 'cbrt', 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', 'cosf': 'cos',
|
||||
'coshf': 'cosh', 'cospif': 'cospi', 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
|
||||
'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru',
|
||||
'ddiv_ru': 'div_ru', 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', 'erfcf': 'erfc', 'erfcinvf':
|
||||
'erfcinv', 'erfcxf': 'erfcx', 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', 'exp2f': 'exp2',
|
||||
'expm1f': 'expm1', 'fabsf': 'abs', 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', 'fdimf': 'fdim', 'ffsll':
|
||||
'ffs', 'floorf': 'floor', 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', 'fmaf_ru': 'fma_ru',
|
||||
'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', 'isinff':
|
||||
'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
|
||||
'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', 'llroundf': 'llround', 'logf': 'log', 'log10f':
|
||||
'log10', 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', 'umax': 'max', 'llmax': 'max', 'ullmax':
|
||||
'max', 'fmaxf': 'max', 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', 'fminf': 'min',
|
||||
'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn',
|
||||
'dmul_ru': 'mul_ru', 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', 'umul24': 'mul24',
|
||||
'umulhi': 'mulhi', 'mul64hi': 'mulhi', 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf':
|
||||
'nextafter', 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', 'normcdfinvf': 'normcdfinv',
|
||||
'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd',
|
||||
'drcp_rd': 'rcp_rd', 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', 'drcp_ru': 'rcp_ru',
|
||||
'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
|
||||
'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', 'roundf': 'round', 'rsqrtf': 'rsqrt',
|
||||
'frsqrt_rn': 'rsqrt_rn', 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', 'signbitd': 'signbit',
|
||||
'sinf': 'sin', 'sinhf': 'sinh', 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', 'dsqrt_rd':
|
||||
'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru',
|
||||
'fsqrt_rz': 'sqrt_rz', 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', 'fsub_rn': 'sub_rn',
|
||||
'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
|
||||
'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', 'y0f': 'y0', 'y1f': 'y1', 'ynf':
|
||||
'yn'
|
||||
}
|
||||
|
||||
for symbol in self._symbols.values():
|
||||
@@ -347,8 +326,7 @@ class LLVMDisassembler:
|
||||
self._ll_file = "/tmp/extern_lib.ll"
|
||||
|
||||
def disasm(self, lib_path: str) -> None:
|
||||
subprocess.Popen([self._path, lib_path, "-o", self.ll_file],
|
||||
stdout=subprocess.PIPE).communicate()
|
||||
subprocess.Popen([self._path, lib_path, "-o", self.ll_file], stdout=subprocess.PIPE).communicate()
|
||||
|
||||
@property
|
||||
def ll_file(self) -> str:
|
||||
|
||||
@@ -40,10 +40,13 @@ if __name__ == "__main__":
|
||||
|
||||
# command-line arguments
|
||||
parser = ArgumentParser(description=desc)
|
||||
parser.add_argument("path", help="Path to Python source containing desired kernel in its scope. File will be executed.")
|
||||
parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", required=True)
|
||||
parser.add_argument("path",
|
||||
help="Path to Python source containing desired kernel in its scope. File will be executed.")
|
||||
parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile",
|
||||
required=True)
|
||||
parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel")
|
||||
parser.add_argument("--num-stages", "-ns", type=int, default=3, help="Number of stages (meta-parameter of the kernel)")
|
||||
parser.add_argument("--num-stages", "-ns", type=int, default=3,
|
||||
help="Number of stages (meta-parameter of the kernel)")
|
||||
parser.add_argument("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel")
|
||||
parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename")
|
||||
parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True)
|
||||
@@ -104,7 +107,8 @@ if __name__ == "__main__":
|
||||
config = triton.compiler.instance_descriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1)
|
||||
for i in equal_to_1:
|
||||
constexprs.update({i: 1})
|
||||
ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config], num_warps=args.num_warps, num_stages=args.num_stages)
|
||||
ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config],
|
||||
num_warps=args.num_warps, num_stages=args.num_stages)
|
||||
arg_names = []
|
||||
arg_types = []
|
||||
for i in signature.keys():
|
||||
|
||||
@@ -27,6 +27,7 @@ class KernelLinkerMeta:
|
||||
|
||||
|
||||
class HeaderParser:
|
||||
|
||||
def __init__(self) -> None:
|
||||
import re
|
||||
|
||||
@@ -42,7 +43,6 @@ class HeaderParser:
|
||||
self.kernels = defaultdict(list)
|
||||
|
||||
def extract_linker_meta(self, header: str):
|
||||
|
||||
for ln in header.splitlines():
|
||||
if ln.startswith("//"):
|
||||
m = self.linker_directives.match(ln)
|
||||
@@ -76,7 +76,7 @@ class HeaderParser:
|
||||
m = self.c_sig.findall(c_sig)
|
||||
if len(m):
|
||||
tys, args = [], []
|
||||
for (ty, arg_name) in m:
|
||||
for ty, arg_name in m:
|
||||
tys.append(ty)
|
||||
args.append(arg_name)
|
||||
return tys, args
|
||||
@@ -84,7 +84,7 @@ class HeaderParser:
|
||||
raise LinkerError(f"{c_sig} is not a valid argument signature")
|
||||
|
||||
def _match_suffix(self, suffix: str, c_sig: str):
|
||||
args = c_sig.split(',')
|
||||
args = c_sig.split(",")
|
||||
s2i = {"c": 1, "d": 16}
|
||||
num_specs = 0
|
||||
sizes = []
|
||||
@@ -110,7 +110,7 @@ class HeaderParser:
|
||||
if name in self.kernels:
|
||||
last: KernelLinkerMeta = self.kernels[name][-1]
|
||||
|
||||
for (cur, new_) in zip(last.arg_ctypes, ker.arg_ctypes):
|
||||
for cur, new_ in zip(last.arg_ctypes, ker.arg_ctypes):
|
||||
if cur != new_:
|
||||
raise LinkerError(
|
||||
f"Mismatched signature for kernel {name}: \n\texisting sig is: {','.join(last.arg_ctypes)}\n\tcurrent is: {','.join(ker.arg_ctypes)}"
|
||||
@@ -152,7 +152,7 @@ void unload_{meta.orig_kernel_name}();
|
||||
# generate dispatcher function for kernels with different meta-parameter and constant values
|
||||
def make_default_algo_kernel(meta: KernelLinkerMeta) -> str:
|
||||
src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n"
|
||||
src += f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n"
|
||||
src += (f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n")
|
||||
src += "}\n"
|
||||
return src
|
||||
|
||||
@@ -164,12 +164,22 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -
|
||||
src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n"
|
||||
src += "\n"
|
||||
|
||||
src += f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{"
|
||||
src += (f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{")
|
||||
src += "\n"
|
||||
for meta in sorted(metas, key=lambda m: -m.num_specs):
|
||||
cond_fn = lambda val, hint: f"({val} % {hint} == 0)" if hint == 16 else f"({val} == {hint})" if hint == 1 else None
|
||||
conds = " && ".join([cond_fn(val, hint) for val, hint in zip(meta.arg_names, meta.sizes) if hint is not None])
|
||||
src += f" if ({conds})\n"
|
||||
cond_fn = ( #
|
||||
lambda val, hint: f"({val} % {hint} == 0)" #
|
||||
if hint == 16 #
|
||||
else f"({val} == {hint})" #
|
||||
if hint == 1 #
|
||||
else None)
|
||||
conds = " && ".join([ #
|
||||
cond_fn(val, hint) #
|
||||
for val, hint in zip(meta.arg_names, meta.sizes) #
|
||||
if hint is not None
|
||||
])
|
||||
src += (f" if ({conds})\n" if any(meta.sizes) else "if (1)\n"
|
||||
) # Edge case where no specializations hence no dispatching required
|
||||
arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1]
|
||||
src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n"
|
||||
src += "\n"
|
||||
@@ -183,7 +193,7 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -
|
||||
src += f"void {mode}_{name}() {{"
|
||||
src += "\n"
|
||||
for meta in sorted(metas, key=lambda m: -m.num_specs):
|
||||
src += f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n"
|
||||
src += (f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n")
|
||||
src += "}\n"
|
||||
return src
|
||||
|
||||
@@ -252,7 +262,12 @@ if __name__ == "__main__":
|
||||
help="Paths to header files to link. Must include linker directive annotations (autogenerated by ttc)",
|
||||
)
|
||||
parser.add_argument("--out", "-o", type=Path, help="Out filename")
|
||||
parser.add_argument("--prefix", type=str, default="", help="String to prefix kernel dispatcher names")
|
||||
parser.add_argument(
|
||||
"--prefix",
|
||||
type=str,
|
||||
default="",
|
||||
help="String to prefix kernel dispatcher names",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# metadata
|
||||
|
||||
Reference in New Issue
Block a user