Reformat Python code with yapf. (#2589)

I've added an option to yapf to do what we want for long lines; see
https://github.com/google/yapf/pull/1177.  We can now have a real Python
formatter, yay!

To make this PR, I ran my modified yapf over the repository, then looked
over the full diff.  Where yapf was mangling the param list of long
function decls/calls (mostly kernels), I manually added `#` to put
linebreaks where we want.  I fixed up other formatting too -- mostly
adding or removing a trailing comma from lists.

Overall, trailing `#` was sufficient to get formatting similar to our
current code.  I didn't have to disable yapf anywhere.

---------

Co-authored-by: Phil Tillet <phil@openai.com>
This commit is contained in:
Justin Lebar
2023-11-02 20:44:17 -07:00
committed by GitHub
parent dced22c4b7
commit df08301e76
85 changed files with 3802 additions and 3880 deletions

View File

@@ -45,12 +45,12 @@ __all__ = [
"tools",
]
# -------------------------------------
# misc. utilities that don't fit well
# into any specific module
# -------------------------------------
def cdiv(x: int, y: int):
    """Ceiling division: the smallest integer >= x / y (intended for positive y)."""
    # Adding y - 1 before the floor division bumps any nonzero remainder up
    # to the next whole multiple, turning floor division into ceiling division.
    numerator = x + y - 1
    return numerator // y

View File

@@ -1,4 +1,3 @@
import functools
import hashlib
import importlib
@@ -16,6 +15,7 @@ TRITON_VERSION = "2.1.0"
class BaseBackend:
def __init__(self, device_type: str) -> None:
self.device_type = device_type
@@ -154,7 +154,7 @@ def compute_core_version_key():
libtriton_hash = hashlib.sha1()
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
while True:
chunk = f.read(1024 ** 2)
chunk = f.read(1024**2)
if not chunk:
break
libtriton_hash.update(chunk)

View File

@@ -86,9 +86,15 @@ def _build(name, src, srcdir):
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
if is_hip():
ret = subprocess.check_call([cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{hip_lib_dir}", "-lamdhip64", "-o", so])
ret = subprocess.check_call([
cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC",
f"-L{hip_lib_dir}", "-lamdhip64", "-o", so
])
else:
cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
cc_cmd = [
cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda",
"-o", so
]
cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
ret = subprocess.check_call(cc_cmd)

View File

@@ -1,5 +1,8 @@
from .compiler import (CompiledKernel, compile, get_arch_default_num_stages,
get_arch_default_num_warps, instance_descriptor)
from .compiler import (CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps,
instance_descriptor)
from .errors import CompilationError
__all__ = ["compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", "get_arch_default_num_stages"]
__all__ = [
"compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps",
"get_arch_default_num_stages"
]

View File

@@ -10,8 +10,7 @@ from .._C.libtriton.triton import ir
from ..language import constexpr, tensor
# ideally we wouldn't need any runtime component
from ..runtime import JITFunction
from .errors import (CompilationError, CompileTimeAssertionFailure,
UnsupportedLanguageConstruct)
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
def mangle_ty(ty):
@@ -68,7 +67,10 @@ def _check_fn_args(node, fn, args):
if fn.noinline:
for idx, arg in enumerate(args):
if not _is_constexpr(arg) and not _is_triton_scalar(arg):
raise UnsupportedLanguageConstruct(fn.src, node, f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}')
raise UnsupportedLanguageConstruct(
fn.src, node,
f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}'
)
def _get_fn_file_line(fn):
@@ -89,6 +91,7 @@ _condition_types = {bool, int, type(None)} # Python types accepted for conditio
class enter_sub_region:
def __init__(self, generator):
self.generator = generator
@@ -109,6 +112,7 @@ class enter_sub_region:
# Check if the given syntax node has an "early" return
class ContainsReturnChecker(ast.NodeVisitor):
def __init__(self, gscope):
self.gscope = gscope
@@ -199,9 +203,10 @@ class ContainsReturnChecker(ast.NodeVisitor):
class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, attributes, constants, function_name, target,
module=None, is_kernel=False, function_types: Optional[Dict] = None,
debug=False, noinline=False, file_name: Optional[str] = None, begin_line=0):
def __init__(self, context, prototype, gscope, attributes, constants, function_name, target, module=None,
is_kernel=False, function_types: Optional[Dict] = None, debug=False, noinline=False,
file_name: Optional[str] = None, begin_line=0):
self.context = context
self.builder = ir.builder(context)
self.file_name = file_name
@@ -237,8 +242,10 @@ class CodeGenerator(ast.NodeVisitor):
))
def _define_name_lookup(self):
def local_lookup(name: str, absent):
value = self.lscope.get(name, absent) # this needs to be re-fetched from `self` every time, because it gets switched occasionally
# this needs to be re-fetched from `self` every time, because it gets switched occasionally
value = self.lscope.get(name, absent)
if value is not absent and name not in self.local_defs:
self.global_uses[name] = value
return value
@@ -255,8 +262,7 @@ class CodeGenerator(ast.NodeVisitor):
return name_lookup
def set_value(self, name: str,
value: Union[tensor, constexpr]) -> None:
def set_value(self, name: str, value: Union[tensor, constexpr]) -> None:
''' This function:
called by visit_Assign() & visit_FunctionDef() to store left value (lvalue)
1. record local defined name (FIXME: should consider control flow)
@@ -338,7 +344,8 @@ class CodeGenerator(ast.NodeVisitor):
self.visit(init_node)
# initialize function
visibility = "public" if self.is_kernel else "private"
self.fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline)
self.fn = self.builder.get_or_insert_function(self.module, self.function_name,
self.prototype.to_ir(self.builder), visibility, self.noinline)
self.module.push_back(self.fn)
entry = self.fn.add_entry_block()
arg_values = []
@@ -469,12 +476,23 @@ class CodeGenerator(ast.NodeVisitor):
rhs = self.visit(node.right)
method_name = self._method_name_for_bin_op.get(type(node.op))
if method_name is None:
raise UnsupportedLanguageConstruct(None, node, "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__))
raise UnsupportedLanguageConstruct(
None, node, "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__))
return self._apply_binary_method(method_name, lhs, rhs)
_method_name_for_bin_op: Dict[Type[ast.operator], str] = {
ast.Add: '__add__', ast.Sub: '__sub__', ast.Mult: '__mul__', ast.Div: '__truediv__',
ast.FloorDiv: '__floordiv__', ast.Mod: '__mod__', ast.Pow: '__pow__',
ast.LShift: '__lshift__', ast.RShift: '__rshift__', ast.BitAnd: '__and__', ast.BitOr: '__or__', ast.BitXor: '__xor__',
ast.Add: '__add__',
ast.Sub: '__sub__',
ast.Mult: '__mul__',
ast.Div: '__truediv__',
ast.FloorDiv: '__floordiv__',
ast.Mod: '__mod__',
ast.Pow: '__pow__',
ast.LShift: '__lshift__',
ast.RShift: '__rshift__',
ast.BitAnd: '__and__',
ast.BitOr: '__or__',
ast.BitXor: '__xor__',
}
def visit_then_else_blocks(self, node, liveins, then_block, else_block):
@@ -508,7 +526,8 @@ class CodeGenerator(ast.NodeVisitor):
if name in then_defs or name in else_defs:
names.append(name)
ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type)
ir_ret_types.append(then_defs[name].handle.get_type() if name in then_defs else else_defs[name].handle.get_type())
ir_ret_types.append(then_defs[name].handle.get_type() if name in
then_defs else else_defs[name].handle.get_type())
# variable defined in then but not in else
if name in then_defs and name not in else_defs:
else_defs[name] = liveins[name]
@@ -602,8 +621,7 @@ class CodeGenerator(ast.NodeVisitor):
contains_return = ContainsReturnChecker(self.gscope).visit(node)
if self.scf_stack and contains_return:
raise UnsupportedLanguageConstruct(
None, node,
"Cannot have `return` statements inside `while` or `for` statements in triton "
None, node, "Cannot have `return` statements inside `while` or `for` statements in triton "
"(note that this also applies to `return` statements that are inside functions "
"transitively called from within `while`/`for` statements)")
elif self.scf_stack or not contains_return:
@@ -612,10 +630,13 @@ class CodeGenerator(ast.NodeVisitor):
self.visit_if_top_level(cond, node)
else:
cond = _unwrap_if_constexpr(cond)
if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks
# not isinstance - we insist the real thing, no subclasses and no ducks
if type(cond) not in _condition_types:
raise UnsupportedLanguageConstruct(
None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
', '.join(_.__name__ for _ in _condition_types), type(cond).__name__))
None, node,
"`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
', '.join(_.__name__ for _ in _condition_types),
type(cond).__name__))
if cond:
self.visit_compound_statement(node.body)
else:
@@ -662,10 +683,14 @@ class CodeGenerator(ast.NodeVisitor):
return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None
else:
cond = _unwrap_if_constexpr(cond)
if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks
# not isinstance - we insist the real thing, no subclasses and no ducks
if type(cond) not in _condition_types:
raise UnsupportedLanguageConstruct(
None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
', '.join(_.__name__ for _ in _condition_types), type(cond).__name__))
None, node,
"`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
', '.join(_.__name__ for _ in _condition_types),
type(cond).__name__))
if cond:
return self.visit(node.body)
else:
@@ -687,8 +712,10 @@ class CodeGenerator(ast.NodeVisitor):
return constexpr(lhs_value is not rhs_value)
method_name = self._method_name_for_comp_op.get(type(node.ops[0]))
if method_name is None:
raise UnsupportedLanguageConstruct(None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
raise UnsupportedLanguageConstruct(
None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
return self._apply_binary_method(method_name, lhs, rhs)
_method_name_for_comp_op: Dict[Type[ast.cmpop], str] = {
ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__'
}
@@ -697,11 +724,15 @@ class CodeGenerator(ast.NodeVisitor):
op = self.visit(node.operand)
fn = self._method_name_for_unary_op.get(type(node.op))
if fn is None:
raise UnsupportedLanguageConstruct(None, node, "AST unary operator '{}' is not (currently) implemented.".format(node.op.__name__))
raise UnsupportedLanguageConstruct(
None, node, "AST unary operator '{}' is not (currently) implemented.".format(node.op.__name__))
if _is_triton_tensor(op):
return getattr(op, fn)(_builder=self.builder)
return getattr(op, fn)()
_method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'}
_method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {
ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'
}
def visit_While(self, node):
with enter_sub_region(self) as sr:
@@ -796,9 +827,7 @@ class CodeGenerator(ast.NodeVisitor):
iter_args = [self.visit(arg) for arg in node.iter.args]
if IteratorClass == language.static_range:
iterator = IteratorClass(*iter_args)
static_range = range(iterator.start.value,
iterator.end.value,
iterator.step.value)
static_range = range(iterator.start.value, iterator.end.value, iterator.step.value)
for i in static_range:
self.lscope[node.target.id] = constexpr(i)
self.visit_compound_statement(node.body)
@@ -935,8 +964,7 @@ class CodeGenerator(ast.NodeVisitor):
def call_JitFunction(self, fn: JITFunction, args, kwargs):
args = inspect.getcallargs(fn.fn, *args, **kwargs)
args = [args[name] for name in fn.arg_names]
args = [arg if _is_triton_tensor(arg)
else constexpr(arg) for arg in args]
args = [arg if _is_triton_tensor(arg) else constexpr(arg) for arg in args]
# generate function def
attributes = dict()
constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
@@ -954,8 +982,9 @@ class CodeGenerator(ast.NodeVisitor):
debug = self.debug if fn.debug is None else fn.debug
file_name, begin_line = _get_fn_file_line(fn)
generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module,
function_name=fn_name, function_types=self.function_ret_types, debug=debug, noinline=fn.noinline,
file_name=file_name, begin_line=begin_line, target=self.builder.target)
function_name=fn_name, function_types=self.function_ret_types, debug=debug,
noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
target=self.builder.target)
generator.visit(fn.parse())
callee_ret_type = generator.last_ret_type
self.function_ret_types[fn_name] = callee_ret_type
@@ -983,7 +1012,7 @@ class CodeGenerator(ast.NodeVisitor):
kws = dict(self.visit(keyword) for keyword in node.keywords)
args = [self.visit(arg) for arg in node.args]
if fn is language.core.device_assert: # TODO: this should not be so hardcoded
if fn is language.core.device_assert: # TODO: this should not be so hardcoded
if not self.debug:
return
if isinstance(fn, JITFunction):
@@ -1004,16 +1033,21 @@ class CodeGenerator(ast.NodeVisitor):
def visit_BoolOp(self, node: ast.BoolOp):
if len(node.values) != 2:
raise UnsupportedLanguageConstruct(None, node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
raise UnsupportedLanguageConstruct(
None, node,
"chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
lhs = self.visit(node.values[0])
rhs = self.visit(node.values[1])
method_name = self._method_name_for_bool_op.get(type(node.op))
if method_name is None:
raise UnsupportedLanguageConstruct(None, node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
raise UnsupportedLanguageConstruct(
None, node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
return self._apply_binary_method(method_name, lhs, rhs)
_method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}
if sys.version_info < (3, 8):
def visit_NameConstant(self, node):
return constexpr(node.value)
@@ -1046,7 +1080,9 @@ class CodeGenerator(ast.NodeVisitor):
evaluated = self.visit(value.value)
if not _is_constexpr(evaluated):
raise UnsupportedLanguageConstruct(
None, node, "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + str(type(evaluated)))
None, node,
"Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type "
+ str(type(evaluated)))
values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value)
else:
raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value)))
@@ -1088,7 +1124,9 @@ class CodeGenerator(ast.NodeVisitor):
passed = _unwrap_if_constexpr(self.visit(node.args[0]))
if not isinstance(passed, bool):
raise NotImplementedError("Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values")
raise NotImplementedError(
"Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values"
)
if not passed:
if arg_count == 1:
message = ""
@@ -1175,10 +1213,9 @@ def ast_to_ttir(fn, signature, specialization, constants, debug, target):
file_name, begin_line = _get_fn_file_line(fn)
prototype = language.function_type([], arg_types)
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants,
function_name=function_name, attributes=new_attrs,
is_kernel=True, debug=debug, file_name=file_name, begin_line=begin_line,
target=target)
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name,
attributes=new_attrs, is_kernel=True, debug=debug, file_name=file_name,
begin_line=begin_line, target=target)
try:
generator.visit(fn.parse())
except CompilationError as e:

View File

@@ -11,10 +11,8 @@ from typing import Any
from dataclasses import dataclass
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs,
compile_ptx_to_cubin, get_env_vars, get_num_warps,
get_shared_memory_size, ir, runtime,
translate_llvmir_to_ptx,
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs, compile_ptx_to_cubin, get_env_vars,
get_num_warps, get_shared_memory_size, ir, runtime, translate_llvmir_to_ptx,
translate_triton_gpu_to_llvmir)
from ..common.backend import get_backend, get_cuda_version_key, path_to_ptxas
from ..common.build import is_hip
@@ -23,13 +21,11 @@ from ..common.build import is_hip
from ..runtime.autotuner import OutOfResources
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
from ..runtime.driver import driver
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device,
get_device_capability)
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device, get_device_capability)
from ..tools.disasm import get_sass
from .code_generator import ast_to_ttir
from .make_launcher import make_stub
from .utils import (InfoFromBackendForTensorMap, TensorMapManager,
get_ids_of_tensormaps, parse_tma_info)
from .utils import (InfoFromBackendForTensorMap, TensorMapManager, get_ids_of_tensormaps, parse_tma_info)
@dataclass
@@ -44,6 +40,7 @@ def _is_cuda(target):
class LazyDict(dict):
def __getitem__(self, key):
val = dict.__getitem__(self, key)
if callable(val):
@@ -94,8 +91,8 @@ def ttir_to_ttgir(mod, num_warps, num_ctas, target):
return mod
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target,
cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue):
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization,
enable_persistent, optimize_epilogue):
is_cuda = _is_cuda(target)
if is_cuda:
capability = target.capability
@@ -173,6 +170,7 @@ def ttgir_to_llir(mod, extern_libs, target, tma_infos):
# PTX translation
@functools.lru_cache()
def ptx_get_version(cuda_version) -> int:
'''
@@ -253,7 +251,8 @@ def make_hash(fn, target, env_vars, device_backend, **kwargs):
enable_persistent = kwargs.get("enable_persistent", False)
debug = kwargs.get("debug", False)
# Get unique key for the compiled code
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1),
sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
configs_key = [get_conf_key(conf) for conf in configs]
env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())]
key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
@@ -299,12 +298,14 @@ else:
def _get_jsonable_constants(constants):
def _is_jsonable(x):
try:
json.dumps(x)
return True
except (TypeError, OverflowError):
return False
serialized_constants = {}
for constant in constants:
if _is_jsonable(constants[constant]):
@@ -319,7 +320,9 @@ def parse_mlir_module(path, context):
return module
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], defaults=[set(), set(), set(), set()])
instance_descriptor = namedtuple("instance_descriptor",
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
defaults=[set(), set(), set(), set()])
def get_cuda_capability(capability):
@@ -355,10 +358,8 @@ def get_arch_default_num_stages(device_type, capability=None):
def add_cuda_stages(target, extern_libs, stages):
stages["ptx"] = (lambda path: Path(path).read_text(),
lambda src: llir_to_ptx(src, target))
stages["cubin"] = (lambda path: Path(path).read_bytes(),
lambda src: ptx_to_cubin(src, target))
stages["ptx"] = (lambda path: Path(path).read_text(), lambda src: llir_to_ptx(src, target))
stages["cubin"] = (lambda path: Path(path).read_bytes(), lambda src: ptx_to_cubin(src, target))
def compile(fn, **kwargs):
@@ -401,7 +402,8 @@ def compile(fn, **kwargs):
# build architecture descriptor
if device_type == "cuda":
_device_backend = get_backend(device_type)
target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps, enable_fp_fusion=enable_fp_fusion)
target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps,
enable_fp_fusion=enable_fp_fusion)
else:
_device_backend = get_backend(device_type)
assert _device_backend
@@ -409,11 +411,12 @@ def compile(fn, **kwargs):
# build compilation stages
stages = dict()
stages["ast"] = (lambda path: fn, None)
stages["ttir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
stages["ttir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttir(
ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
if is_cuda:
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
stages["ttgir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttgir(
ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info,
enable_warp_specialization, enable_persistent, optimize_epilogue))
stages["llir"] = (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, target, tma_infos))
add_cuda_stages(target, extern_libs, stages)
@@ -451,7 +454,8 @@ def compile(fn, **kwargs):
if ir_name == 'ttgir':
num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
assert "num_warps" not in kwargs or int(num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile"
assert "num_warps" not in kwargs or int(
num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile"
num_warps = int(num_warps_matches[0])
param_tys = [convert_type_repr(ty) for ty in types]
signature = {k: v for k, v in enumerate(param_tys)}
@@ -461,8 +465,10 @@ def compile(fn, **kwargs):
fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs))
# managers used to dump and override IR for debugging
enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
fn_override_manager = get_override_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
fn_dump_manager = get_dump_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
fn_override_manager = get_override_manager(
make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
fn_dump_manager = get_dump_manager(
make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
# determine name and extension type of provided function
if isinstance(fn, JITFunction):
@@ -475,9 +481,7 @@ def compile(fn, **kwargs):
metadata_filename = f"{name}.json"
# The group is addressed by the metadata
metadata_group = fn_cache_manager.get_group(
metadata_filename
) or {}
metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
metadata_path = metadata_group.get(metadata_filename)
@@ -485,17 +489,18 @@ def compile(fn, **kwargs):
with open(metadata_path) as f:
metadata = json.load(f)
if 'tensormaps_info' in metadata:
metadata['tensormaps_info'] = [
InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']]
metadata['tensormaps_info'] = [InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']]
else:
metadata = {"num_warps": num_warps,
"num_ctas": num_ctas,
"num_stages": num_stages,
"enable_warp_specialization": enable_warp_specialization,
"enable_persistent": enable_persistent,
"constants": _get_jsonable_constants(constants),
"debug": debug,
"target": target, }
metadata = {
"num_warps": num_warps,
"num_ctas": num_ctas,
"num_stages": num_stages,
"enable_warp_specialization": enable_warp_specialization,
"enable_persistent": enable_persistent,
"constants": _get_jsonable_constants(constants),
"debug": debug,
"target": target,
}
metadata.update(get_env_vars())
if ext == "ptx":
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
@@ -567,10 +572,7 @@ def compile(fn, **kwargs):
ids_of_folded_args = tuple([int(k) for k in configs[0].ids_of_folded_args]) if isinstance(fn, JITFunction) else ()
if "clusterDims" not in metadata:
metadata["clusterDims"] = [
cluster_info.clusterDimX,
cluster_info.clusterDimY,
cluster_info.clusterDimZ]
metadata["clusterDims"] = [cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ]
if len(tma_infos) > 0:
metadata["tensormaps_info"] = parse_tma_info(tma_infos, ids_of_folded_args)
@@ -584,7 +586,10 @@ def compile(fn, **kwargs):
fn.tensormaps_info = metadata["tensormaps_info"]
ids_of_const_exprs = tuple(fn.constexprs) if isinstance(fn, JITFunction) else ()
ids = {"ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs": ids_of_const_exprs}
ids = {
"ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs":
ids_of_const_exprs
}
# cache manager
if is_cuda:
so_path = make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization)
@@ -592,7 +597,8 @@ def compile(fn, **kwargs):
so_path = _device_backend.make_launcher_stub(name, signature, constants, ids)
# write-back metadata, if it didn't come from the cache
if metadata_path is None:
metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, binary=False)
metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
binary=False)
fn_cache_manager.put_group(metadata_filename, metadata_group)
# return handle to compiled kernel
@@ -640,10 +646,7 @@ class CompiledKernel:
if self.device_type in ["cuda"]:
device = get_current_device()
bin_path = {
driver.HIP: "hsaco_path",
driver.CUDA: "cubin"
}[driver.backend]
bin_path = {driver.HIP: "hsaco_path", driver.CUDA: "cubin"}[driver.backend]
max_shared = driver.utils.get_device_properties(device)["max_shared_mem"]
fn_load_binary = driver.utils.load_binary
else:
@@ -691,4 +694,5 @@ class CompiledKernel:
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.clusterDims[0],
self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function,
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand)
return runner

View File

@@ -40,6 +40,7 @@ def make_stub(name, signature, constants, ids, **kwargs):
else:
return cache_path
# ----- source code generation --------
@@ -100,7 +101,10 @@ def generate_launcher(constants, signature, ids):
# generate glue code
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
params = [i for i in signature.keys() if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)]
params = [
i for i in signature.keys()
if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)
]
src = f"""
#include \"cuda.h\"
#include <stdbool.h>

View File

@@ -158,19 +158,21 @@ class InfoFromBackendForTensorMap:
# dtype:cuda.CUtensorMapDataType | int
def bytes_from_type(self, dtype):
return {driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT8"]: 1,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT16"]: 2,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT64"]: 8,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT64"]: 8,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]: 2,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT64"]: 8,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"]: 2,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ"]: 4}[dtype]
return {
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT8"]: 1,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT16"]: 2,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT64"]: 8,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT64"]: 8,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]: 2,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT64"]: 8,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"]: 2,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32"]: 4,
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ"]: 4
}[dtype]
def getTensorMapDataType(self):
return self.tensorDataType
@@ -259,22 +261,29 @@ class InfoFromBackendForTensorMap:
self.getInterleave(),
self.getSwizzle(),
self.getL2Promotion(),
self.getOobFill()
self.getOobFill(),
)
# make hashable to use as partial key in cache
def __hash__(self):
return hash((self.ids_of_folded_args, self.globalAddressArgIdx, tuple(self.globalDimsArgIdx), tuple(self.globalStridesArgIdx), self.tensorDataType,
self.tensorRank, tuple(self.boxDims), tuple(self.elementStrides), self.interleave, self.swizzle, self.l2Promotion, self.oobFill))
return hash((self.ids_of_folded_args, self.globalAddressArgIdx, tuple(self.globalDimsArgIdx),
tuple(self.globalStridesArgIdx), self.tensorDataType, self.tensorRank, tuple(self.boxDims),
tuple(self.elementStrides), self.interleave, self.swizzle, self.l2Promotion, self.oobFill))
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
return (self.ids_of_folded_args, self.globalAddressArgIdx, self.globalDimsArgIdx, self.globalStridesArgIdx, self.tensorDataType, self.tensorRank, self.boxDims, self.elementStrides, self.interleave, self.swizzle, self.l2Promotion, self.oobFill) == (
other.ids_of_folded_args, other.globalAddressArgIdx, other.globalDimsArgIdx, other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims, other.elementStrides, other.interleave, other.swizzle, other.l2Promotion, other.oobFill)
return (self.ids_of_folded_args, self.globalAddressArgIdx, self.globalDimsArgIdx, self.globalStridesArgIdx,
self.tensorDataType, self.tensorRank, self.boxDims, self.elementStrides, self.interleave, self.swizzle,
self.l2Promotion,
self.oobFill) == (other.ids_of_folded_args, other.globalAddressArgIdx, other.globalDimsArgIdx,
other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims,
other.elementStrides, other.interleave, other.swizzle, other.l2Promotion,
other.oobFill)
class TensorMapManager:
def __init__(self):
self.tensormaps_device = {}
@@ -286,8 +295,7 @@ class TensorMapManager:
t_tensormap = e.tensormap(args)
TENSORMAP_SIZE_IN_BYTES = 128
t_tensormap_device = driver.utils.cuMemAlloc(TENSORMAP_SIZE_IN_BYTES)
driver.utils.cuMemcpyHtoD(
t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES)
driver.utils.cuMemcpyHtoD(t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES)
self.tensormaps_device[key] = t_tensormap_device
return int(self.tensormaps_device[key])

View File

@@ -109,7 +109,6 @@ from .random import (
uint32_to_uniform_float,
)
__all__ = [
"TRITON_MAX_TENSOR_NUMEL",
"abs",

View File

@@ -22,10 +22,8 @@ def builtin(fn: T) -> T:
@wraps(fn)
def wrapper(*args, **kwargs):
if "_builder" not in kwargs or kwargs["_builder"] is None:
raise ValueError(
"Did you forget to add @triton.jit ? "
"(`_builder` argument must be provided outside of JIT functions.)"
)
raise ValueError("Did you forget to add @triton.jit ? "
"(`_builder` argument must be provided outside of JIT functions.)")
return fn(*args, **kwargs)
setattr(wrapper, TRITON_BUILTIN, True)
@@ -54,7 +52,7 @@ def _to_tensor(x, builder):
else:
raise RuntimeError(f'Nonrepresentable integer {x}.')
elif isinstance(x, float):
min_float32 = 2 ** -126
min_float32 = 2**-126
max_float32 = (2 - 2**-23) * 2**127
abs_x = __builtins__['abs'](x)
if abs_x == float("inf") or\
@@ -229,7 +227,7 @@ class dtype:
return not self.__eq__(other)
def __hash__(self):
return hash((self.name,))
return hash((self.name, ))
@property
def scalar(self):
@@ -279,6 +277,7 @@ class dtype:
class pointer_type(dtype):
def __init__(self, element_ty: dtype, address_space: int = 1):
if not isinstance(element_ty, dtype):
raise TypeError('element_ty is a {type(element_ty).__name__}.')
@@ -313,6 +312,7 @@ class pointer_type(dtype):
class block_type(dtype):
def __init__(self, element_ty: dtype, shape: List):
self.element_ty = element_ty
@@ -363,6 +363,7 @@ class block_type(dtype):
class function_type(dtype):
def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None:
self.ret_types = ret_types
self.param_types = param_types
@@ -511,7 +512,7 @@ class constexpr:
return constexpr(~self.value)
def __pow__(self, other):
return constexpr(self.value ** other.value)
return constexpr(self.value**other.value)
def __rshift__(self, other):
return constexpr(self.value >> other.value)
@@ -527,6 +528,7 @@ class constexpr:
class tensor:
def __init__(self, handle, type: dtype):
# IR handle
self.handle = handle
@@ -993,6 +995,7 @@ def expand_dims(input, axis, _builder=None):
ret = semantic.expand_dims(ret, a, _builder)
return ret
# -----------------------
# Linear Algebra
# -----------------------
@@ -1141,6 +1144,7 @@ def advance(base: tensor, offsets, _builder=None):
"""
return semantic.advance(base, offsets, _builder)
# -----------------------
# Atomic Memory Operations
# -----------------------
@@ -1253,6 +1257,7 @@ def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None):
# Conditioning
# -----------------------
@builtin
def where(condition, x, y, _builder=None):
"""
@@ -1280,6 +1285,7 @@ def where(condition, x, y, _builder=None):
# Math
# -----------------------
@builtin
def umulhi(x, y, _builder=None):
"""
@@ -1373,6 +1379,7 @@ def abs(x, _builder=None):
# Reductions
# -----------------------
def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]:
def _decorator(func: T) -> T:
@@ -1411,8 +1418,7 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None):
"""
if isinstance(input, tensor):
return reduce((input,), axis, combine_fn,
_builder=_builder, _generator=_generator)[0]
return reduce((input, ), axis, combine_fn, _builder=_builder, _generator=_generator)[0]
def make_combine_region(reduce_op):
in_scalar_tys = [t.type.scalar for t in input]
@@ -1422,14 +1428,14 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None):
with _insertion_guard(_builder):
param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
block = _builder.create_block_with_parent(region, param_types)
args = [tensor(block.arg(i), ty)
for i, ty in enumerate(prototype.param_types)]
args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
results = _generator.call_JitFunction(combine_fn, args, kwargs={})
if isinstance(results, tensor):
handles = [results.handle]
else:
handles = [r.handle for r in results]
_builder.create_reduce_ret(*handles)
if axis is not None:
axis = _constexpr_to_value(axis)
return semantic.reduction(input, axis, make_combine_region, _builder)
@@ -1459,8 +1465,7 @@ def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None
index = expand_dims(index, axes_to_expand, _builder=_builder)
index = broadcast_to(index, input.shape, _builder=_builder)
rvalue, rindices = reduce((input, index), axis, combine_fn,
_builder=_builder, _generator=_generator)
rvalue, rindices = reduce((input, index), axis, combine_fn, _builder=_builder, _generator=_generator)
return rvalue, rindices
@@ -1468,6 +1473,7 @@ def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None
# Scans
# -----------------------
def _add_scan_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]:
def _decorator(func: T) -> T:
@@ -1492,8 +1498,7 @@ def associative_scan(input, axis, combine_fn, _builder=None, _generator=None):
"""
if isinstance(input, tensor):
return associative_scan((input,), axis, combine_fn,
_builder=_builder, _generator=_generator)[0]
return associative_scan((input, ), axis, combine_fn, _builder=_builder, _generator=_generator)[0]
def make_combine_region(scan_op):
in_scalar_tys = [t.type.scalar for t in input]
@@ -1503,17 +1508,18 @@ def associative_scan(input, axis, combine_fn, _builder=None, _generator=None):
with _insertion_guard(_builder):
param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
block = _builder.create_block_with_parent(region, param_types)
args = [tensor(block.arg(i), ty)
for i, ty in enumerate(prototype.param_types)]
args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
results = _generator.call_JitFunction(combine_fn, args, kwargs={})
if isinstance(results, tensor):
handles = [results.handle]
else:
handles = [r.handle for r in results]
_builder.create_scan_ret(*handles)
axis = _constexpr_to_value(axis)
return semantic.associative_scan(input, axis, make_combine_region, _builder)
# -----------------------
# Compiler Hint Ops
# -----------------------
@@ -1576,6 +1582,8 @@ def max_constancy(input, values, _builder=None):
raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
values = [x.value for x in values]
return semantic.max_constancy(input, values)
# -----------------------
# Debugging functions
# -----------------------
@@ -1715,12 +1723,12 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur
broadcast_arg = dispatch_args[0]
# Get the broadcast shape over all the arguments
for i, item in enumerate(dispatch_args):
_, broadcast_arg = semantic.binary_op_type_checking_impl(
item, broadcast_arg, _builder, arithmetic_check=False)
_, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder,
arithmetic_check=False)
# Change the shape of each argument based on the broadcast shape
for i in range(len(dispatch_args)):
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
dispatch_args[i], broadcast_arg, _builder, arithmetic_check=False)
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder,
arithmetic_check=False)
ret_shape = broadcast_arg.shape
res_ty = block_type(dtype, ret_shape)
call = _builder.create_inline_asm(asm, constraints, [t.handle for t in args], res_ty.to_ir(_builder), is_pure, pack)
@@ -1733,7 +1741,6 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur
class static_range:
"""
Iterator that counts upward forever.
@@ -1777,7 +1784,9 @@ class static_range:
# Extern functions
# -----------------------
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, is_pure: bool, _builder=None):
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple,
is_pure: bool, _builder=None):
'''
Dispatch a function to a library
:param func: the function to dispatch
@@ -1819,7 +1828,8 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic
return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type)
def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, _builder=None):
def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool,
_builder=None):
'''
Dispatch an elementwise function to a library
:param lib_name: the name of the library
@@ -1848,12 +1858,12 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
broadcast_arg = dispatch_args[0]
# Get the broadcast shape over all the arguments
for i, item in enumerate(dispatch_args):
_, broadcast_arg = semantic.binary_op_type_checking_impl(
item, broadcast_arg, _builder, arithmetic_check=arithmetic_check)
_, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder,
arithmetic_check=arithmetic_check)
# Change the shape of each argument based on the broadcast shape
for i in range(len(dispatch_args)):
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
dispatch_args[i], broadcast_arg, _builder, arithmetic_check=arithmetic_check)
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder,
arithmetic_check=arithmetic_check)
if not all_scalar:
ret_shape = broadcast_arg.shape
func = getattr(_builder, "create_extern_elementwise")

View File

@@ -3,16 +3,14 @@ from .. import core
@core.extern
def globaltimer(_builder=None):
return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [],
dtype=core.int64, is_pure=False,
pack=1, _builder=_builder)
return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], dtype=core.int64, is_pure=False, pack=1,
_builder=_builder)
@core.extern
def smid(_builder=None):
return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [],
dtype=core.int32, is_pure=True,
pack=1, _builder=_builder)
return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1,
_builder=_builder)
@core.builtin

File diff suppressed because it is too large Load Diff

View File

@@ -91,6 +91,7 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
# two_to_the_minus_32: tl.constexpr = 2.328306e-10
# return x * two_to_the_minus_32
@jit
def uint32_to_uniform_float(x):
"""
@@ -134,6 +135,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
u4 = uint32_to_uniform_float(i4)
return u1, u2, u3, u4
# -------------------
# randn
# -------------------

View File

@@ -16,10 +16,12 @@ def _is_cuda(target):
from ..compiler.compiler import CudaTargetDescriptor
return isinstance(target, CudaTargetDescriptor)
# Create custom exception that prints message "hello"
class IncompatibleTypeErrorImpl(Exception):
def __init__(self, type_a, type_b):
self.type_a = type_a
self.type_b = type_b
@@ -31,6 +33,7 @@ class IncompatibleTypeErrorImpl(Exception):
# Programming Model
# ===----------------------------------------------------------------------===##
def program_id(axis: int, builder: ir.builder) -> tl.tensor:
if axis not in (0, 1, 2):
raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}")
@@ -42,6 +45,7 @@ def num_programs(axis: int, builder: ir.builder) -> tl.tensor:
raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}")
return tl.tensor(builder.create_get_num_programs(axis), tl.int32)
# ===----------------------------------------------------------------------===//
# Implicit Casting Utilities
# ===----------------------------------------------------------------------===//
@@ -92,10 +96,12 @@ def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> t
# 5 ) both operands are integer and undergo
# integer promotion
if div_or_mod and a_ty.int_signedness != b_ty.int_signedness:
raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + " because they have different signedness;"
raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() +
" because they have different signedness;"
"this is unlikely to result in a useful answer. Cast them to the same signedness.")
return integer_promote_impl(a_ty, b_ty)
# ===----------------------------------------------------------------------===//
# Binary Operators
# ===----------------------------------------------------------------------===//
@@ -113,12 +119,9 @@ def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -
raise IncompatibleTypeErrorImpl(type_a, type_b)
def binary_op_type_checking_impl(lhs: tl.tensor,
rhs: tl.tensor,
builder: ir.builder,
allow_lhs_ptr=False, allow_rhs_ptr=False,
arithmetic_check=True, div_or_mod=False
) -> Tuple[tl.tensor, tl.tensor]:
def binary_op_type_checking_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder, allow_lhs_ptr=False,
allow_rhs_ptr=False, arithmetic_check=True,
div_or_mod=False) -> Tuple[tl.tensor, tl.tensor]:
# implicit broadcasting
lhs, rhs = broadcast_impl_value(lhs, rhs, builder)
# implicit typecasting
@@ -133,9 +136,7 @@ def binary_op_type_checking_impl(lhs: tl.tensor,
return lhs, rhs
def add(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def add(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder, True, True)
input_scalar_ty = input.type.scalar
other_scalar_ty = other.type.scalar
@@ -159,15 +160,12 @@ def add(input: tl.tensor,
assert False
def sub(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def sub(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder, True, False)
scalar_ty = input.type.scalar
# ptr - offset
if scalar_ty.is_ptr():
return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle),
input.type)
return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle), input.type)
# float - float
if scalar_ty.is_floating():
return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type)
@@ -177,9 +175,7 @@ def sub(input: tl.tensor,
assert False
def mul(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def mul(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder)
scalar_ty = input.type.scalar
# float * float
@@ -191,9 +187,7 @@ def mul(input: tl.tensor,
assert False
def truediv(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def truediv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
input_scalar_ty = input.type.scalar
other_scalar_ty = other.type.scalar
@@ -219,9 +213,7 @@ def truediv(input: tl.tensor,
return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type)
def floordiv(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def floordiv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
input_scalar_ty = input.type.scalar
other_scalar_ty = other.type.scalar
@@ -236,10 +228,7 @@ def floordiv(input: tl.tensor,
assert False
def fdiv(input: tl.tensor,
other: tl.tensor,
ieee_rounding: bool,
builder: ir.builder) -> tl.tensor:
def fdiv(input: tl.tensor, other: tl.tensor, ieee_rounding: bool, builder: ir.builder) -> tl.tensor:
input_scalar_ty = input.type.scalar
other_scalar_ty = other.type.scalar
if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating():
@@ -249,18 +238,14 @@ def fdiv(input: tl.tensor,
return tl.tensor(ret, input.type)
def mod(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def mod(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
scalar_ty = input.type.scalar
other_scalar_ty = other.type.scalar
# float % float
if scalar_ty.is_floating():
# input - input.div(other, rounding_mode="floor") * other
ret = sub(input, mul(floor(fdiv(input, other, False, builder), builder),
other, builder),
builder)
ret = sub(input, mul(floor(fdiv(input, other, False, builder), builder), other, builder), builder)
return ret
# % int
elif scalar_ty.is_int():
@@ -274,13 +259,13 @@ def mod(input: tl.tensor,
return tl.tensor(builder.create_urem(input.handle, other.handle), input.type)
assert False
##############
# bitwise ops
##############
def bitwise_op_type_checking_impl(input: tl.tensor,
other: tl.tensor,
def bitwise_op_type_checking_impl(input: tl.tensor, other: tl.tensor,
builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]:
input, other = binary_op_type_checking_impl(input, other, builder, False, False, False)
input_sca_ty = input.type.scalar
@@ -295,23 +280,17 @@ def bitwise_op_type_checking_impl(input: tl.tensor,
return input, other
def and_(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def and_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = bitwise_op_type_checking_impl(input, other, builder)
return tl.tensor(builder.create_and(input.handle, other.handle), input.type)
def or_(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def or_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = bitwise_op_type_checking_impl(input, other, builder)
return tl.tensor(builder.create_or(input.handle, other.handle), input.type)
def xor_(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def xor_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = bitwise_op_type_checking_impl(input, other, builder)
return tl.tensor(builder.create_xor(input.handle, other.handle), input.type)
@@ -338,26 +317,21 @@ def not_(input: tl.tensor, builder: ir.builder):
return invert(input, builder)
def lshr(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def lshr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = bitwise_op_type_checking_impl(input, other, builder)
return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type)
def ashr(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def ashr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = bitwise_op_type_checking_impl(input, other, builder)
return tl.tensor(builder.create_ashr(input.handle, other.handle), input.type)
def shl(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def shl(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = bitwise_op_type_checking_impl(input, other, builder)
return tl.tensor(builder.create_shl(input.handle, other.handle), input.type)
# ===----------------------------------------------------------------------===//
# Unary Operators
# ===----------------------------------------------------------------------===//
@@ -367,8 +341,7 @@ def plus(input: tl.tensor) -> tl.tensor:
return input
def minus(input: tl.tensor,
builder: ir.builder) -> tl.tensor:
def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor:
input_sca_ty = input.type.scalar
if input_sca_ty.is_ptr():
raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")")
@@ -376,8 +349,7 @@ def minus(input: tl.tensor,
return sub(_0, input, builder)
def invert(input: tl.tensor,
builder: tl.tensor) -> tl.tensor:
def invert(input: tl.tensor, builder: tl.tensor) -> tl.tensor:
input_sca_ty = input.type.scalar
if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")")
@@ -395,9 +367,7 @@ def _bool_like(v: tl.tensor) -> tl.block_type:
return tl.block_type(tl.int1, shape)
def greater_than(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def greater_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder)
scalar_ty = input.type.scalar
# float > float
@@ -412,9 +382,7 @@ def greater_than(input: tl.tensor,
assert False
def greater_equal(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def greater_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder)
scalar_ty = input.type.scalar
# float >= float
@@ -429,9 +397,7 @@ def greater_equal(input: tl.tensor,
assert False
def less_than(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def less_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder)
scalar_ty = input.type.scalar
# float < float
@@ -446,9 +412,7 @@ def less_than(input: tl.tensor,
assert False
def less_equal(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def less_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder)
scalar_ty = input.type.scalar
# float < float
@@ -463,9 +427,7 @@ def less_equal(input: tl.tensor,
assert False
def equal(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder)
scalar_ty = input.type.scalar
# float == float
@@ -477,9 +439,7 @@ def equal(input: tl.tensor,
assert False
def not_equal(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
def not_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder)
scalar_ty = input.type.scalar
# float == float
@@ -490,6 +450,7 @@ def not_equal(input: tl.tensor,
return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input))
assert False
# ===----------------------------------------------------------------------===//
# Block Creation
# ===----------------------------------------------------------------------===//
@@ -532,6 +493,7 @@ def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.te
# Shape Manipulation
# ===----------------------------------------------------------------------===//
def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor:
assert not value.type.is_block(), "Cannot splat a block tensor"
if len(shape) == 0:
@@ -540,9 +502,7 @@ def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_splat(value.handle, shape), ret_ty)
def view(input: tl.tensor,
dst_shape: List[int],
builder: ir.builder) -> tl.tensor:
def view(input: tl.tensor, dst_shape: List[int], builder: ir.builder) -> tl.tensor:
# TODO: disable when TritonToTritonGPU handles views properly
# assert len(input.shape) == len(dst_shape)
@@ -555,9 +515,7 @@ def view(input: tl.tensor,
return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty)
def reshape(input: tl.tensor,
dst_shape: List[int],
builder: ir.builder) -> tl.tensor:
def reshape(input: tl.tensor, dst_shape: List[int], builder: ir.builder) -> tl.tensor:
raise ValueError("`reshape` is not supported yet. Please use `view` instead if applicable. "
"Note that view may reorder elements in an implementation- and context- dependent way.")
@@ -587,9 +545,7 @@ def trans(input: tl.tensor, builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_trans(input.handle), ret_type)
def broadcast_impl_shape(input: tl.tensor,
shape: List[int],
builder: ir.builder) -> tl.tensor:
def broadcast_impl_shape(input: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor:
if not input.type.is_block():
ret_ty = tl.block_type(input.type, shape)
return tl.tensor(builder.create_splat(input.handle, shape), ret_ty)
@@ -607,9 +563,7 @@ def broadcast_impl_shape(input: tl.tensor,
return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty)
def broadcast_impl_value(lhs: tl.tensor,
rhs: tl.tensor,
builder: ir.builder) -> tl.tensor:
def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:
lhs_ty = lhs.type
rhs_ty = rhs.type
@@ -629,13 +583,15 @@ def broadcast_impl_value(lhs: tl.tensor,
if len(lhs_shape) < len(rhs_shape):
# Add new axes to lhs
for dim in range(len(lhs_shape), len(rhs_shape)):
lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0), tl.block_type(lhs_ty.scalar, [1] + lhs_shape))
lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0),
tl.block_type(lhs_ty.scalar, [1] + lhs_shape))
lhs_ty = lhs.type
lhs_shape = lhs_ty.get_block_shapes()
elif len(rhs_shape) < len(lhs_shape):
# Add new axes to rhs
for dim in range(len(rhs_shape), len(lhs_shape)):
rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0), tl.block_type(rhs_ty.scalar, [1] + rhs_shape))
rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0),
tl.block_type(rhs_ty.scalar, [1] + rhs_shape))
rhs_ty = rhs.type
rhs_shape = rhs_ty.get_block_shapes()
assert len(rhs_shape) == len(lhs_shape)
@@ -661,14 +617,13 @@ def broadcast_impl_value(lhs: tl.tensor,
# (scalar, scalar) => returns original blocks
return lhs, rhs
#######
# cast
#######
def bitcast(input: tl.tensor,
dst_ty: tl.dtype,
builder: ir.builder) -> tl.tensor:
def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor:
src_ty = input.type
if src_ty.is_block():
dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes())
@@ -684,13 +639,10 @@ def bitcast(input: tl.tensor,
if src_bits != dst_bits:
raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to "
"data-type of size " + str(dst_bits))
return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)),
dst_ty)
return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty)
def cast(input: tl.tensor,
dst_ty: tl.dtype,
builder: ir.builder) -> tl.tensor:
def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor:
src_ty = input.type
if isinstance(dst_ty, tl.constexpr):
dst_ty = dst_ty.value
@@ -709,8 +661,7 @@ def cast(input: tl.tensor,
# Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64
if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \
(src_sca_ty.is_floating() and dst_sca_ty.is_fp8()):
return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)),
dst_ty)
return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty)
# bf16 <=> (not fp32)
if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \
@@ -724,9 +675,7 @@ def cast(input: tl.tensor,
dst_sca_ty.is_floating() and \
src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth
if truncate_fp:
return tl.tensor(builder.create_fp_trunc(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty)
# Standard floating types' casting: extension
# fp32 => fp64
@@ -736,9 +685,7 @@ def cast(input: tl.tensor,
dst_sca_ty.is_floating() and \
src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth
if ext_fp:
return tl.tensor(builder.create_fp_ext(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty)
# Casting between integer types
if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
@@ -749,9 +696,7 @@ def cast(input: tl.tensor,
_0 = tl.tensor(builder.get_null_value(ty), input.dtype)
return not_equal(input, _0, builder)
else:
return tl.tensor(builder.create_int_cast(input.handle,
dst_ty.to_ir(builder), sign_extend),
dst_ty)
return tl.tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty)
# Casting standard floating types to integer types
if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int():
@@ -760,35 +705,24 @@ def cast(input: tl.tensor,
_0 = tl.tensor(builder.get_null_value(ty), input.dtype)
return not_equal(input, _0, builder)
elif dst_sca_ty.is_int_signed():
return tl.tensor(builder.create_fp_to_si(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
return tl.tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty)
else:
return tl.tensor(builder.create_fp_to_ui(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
return tl.tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty)
# Casting integer types to standard floating types
if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating():
if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed():
return tl.tensor(builder.create_ui_to_fp(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
return tl.tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty)
else:
return tl.tensor(builder.create_si_to_fp(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
return tl.tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty)
# Casting pointer types to integer types
if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
bitwidth = dst_sca_ty.int_bitwidth
if bitwidth == 64:
return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)),
dst_ty)
return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty)
if bitwidth == 1:
return not_equal(cast(input, tl.int64, builder),
tl.tensor(builder.get_int64(0), tl.int64),
builder)
return not_equal(cast(input, tl.int64, builder), tl.tensor(builder.get_int64(0), tl.int64), builder)
# Casting integer types to pointer types
if src_sca_ty.is_int() and dst_sca_ty.is_ptr():
@@ -800,6 +734,7 @@ def cast(input: tl.tensor,
assert False, f'cannot cast {input} to {dst_ty}'
# ===----------------------------------------------------------------------===//
# Memory Operators
# ===----------------------------------------------------------------------===//
@@ -918,8 +853,8 @@ def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, evicti
boundary_check = _canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes())
# Build IR
return tl.tensor(builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction,
is_volatile), dst_ty)
return tl.tensor(
builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty)
def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder):
@@ -975,19 +910,13 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_
if not mask:
return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty)
else:
return tl.tensor(builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache,
eviction, is_volatile), dst_ty)
return tl.tensor(
builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction,
is_volatile), dst_ty)
def load(ptr: tl.tensor,
mask: Optional[tl.tensor],
other: Optional[tl.tensor],
boundary_check,
padding_option: str,
cache_modifier: str,
eviction_policy: str,
is_volatile: bool,
builder: ir.builder) -> tl.tensor:
def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check, padding_option: str,
cache_modifier: str, eviction_policy: str, is_volatile: bool, builder: ir.builder) -> tl.tensor:
# Cache, eviction and padding options
cache = _str_to_load_cache_modifier(cache_modifier)
eviction = _str_to_eviction_policy(eviction_policy)
@@ -1012,7 +941,8 @@ def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builde
if not val.type.is_block():
val = broadcast_impl_shape(val, block_shape, builder)
assert val.type.is_block(), "Value argument must be block type or a scalar"
assert block_shape == val.type.get_block_shapes(), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch"
assert block_shape == val.type.get_block_shapes(
), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch"
assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch"
elt_ty = ptr.type.element_ty.element_ty
@@ -1070,13 +1000,8 @@ def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder):
return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void)
def store(ptr: tl.tensor,
val: tl.tensor,
mask: Optional[tl.tensor],
boundary_check,
cache_modifier: str,
eviction_policy: str,
builder: ir.builder) -> tl.tensor:
def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], boundary_check, cache_modifier: str,
eviction_policy: str, builder: ir.builder) -> tl.tensor:
# Cache and eviction options
cache = _str_to_store_cache_modifier(cache_modifier)
eviction = _str_to_eviction_policy(eviction_policy)
@@ -1094,12 +1019,7 @@ def store(ptr: tl.tensor,
#########
def atomic_cas(ptr: tl.tensor,
cmp: tl.tensor,
val: tl.tensor,
sem: str,
scope: str,
builder: ir.builder) -> tl.tensor:
def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
sem = _str_to_sem(sem)
scope = _str_to_scope(scope)
element_ty = ptr.type.scalar.element_ty
@@ -1108,10 +1028,7 @@ def atomic_cas(ptr: tl.tensor,
return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type)
def atom_red_typechecking_impl(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
op: str,
def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, op: str,
builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]:
if not ptr.type.scalar.is_ptr():
raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
@@ -1136,12 +1053,7 @@ def atom_red_typechecking_impl(ptr: tl.tensor,
return ptr, val, mask
def atomic_max(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
scope: str,
builder: ir.builder) -> tl.tensor:
def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder)
sem = _str_to_sem(sem)
scope = _str_to_scope(scope)
@@ -1149,21 +1061,11 @@ def atomic_max(ptr: tl.tensor,
# direct call to atomic_max for integers
if sca_ty.is_int():
if sca_ty.is_int_signed():
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX,
ptr.handle,
val.handle,
mask.handle,
sem,
scope),
val.type)
return tl.tensor(
builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
else:
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX,
ptr.handle,
val.handle,
mask.handle,
sem,
scope),
val.type)
return tl.tensor(
builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
# for float
# return atomic_smax(i_ptr, i_val) if val >= 0
# return atomic_umin(i_ptr, i_val) if val < 0
@@ -1177,18 +1079,17 @@ def atomic_max(ptr: tl.tensor,
i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder)
pos = greater_equal(val, zero, builder)
neg = less_than(val, zero, builder)
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle, sem, scope), i_val.type)
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle, sem, scope), i_val.type)
pos_ret = tl.tensor(
builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle,
and_(mask, pos, builder).handle, sem, scope), i_val.type)
neg_ret = tl.tensor(
builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle,
and_(mask, neg, builder).handle, sem, scope), i_val.type)
ret = where(pos, pos_ret, neg_ret, builder)
return bitcast(ret, sca_ty, builder)
def atomic_min(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
scope: str,
builder: ir.builder) -> tl.tensor:
def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder)
sem = _str_to_sem(sem)
scope = _str_to_scope(scope)
@@ -1196,21 +1097,11 @@ def atomic_min(ptr: tl.tensor,
# direct call to atomic_min for integers
if sca_ty.is_int():
if sca_ty.is_int_signed():
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
ptr.handle,
val.handle,
mask.handle,
sem,
scope),
val.type)
return tl.tensor(
builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
else:
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN,
ptr.handle,
val.handle,
mask.handle,
sem,
scope),
val.type)
return tl.tensor(
builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
# for float
# return atomic_smin(i_ptr, i_val) if val >= 0
# return atomic_umax(i_ptr, i_val) if val < 0
@@ -1224,30 +1115,17 @@ def atomic_min(ptr: tl.tensor,
i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder)
pos = greater_equal(val, zero, builder)
neg = less_than(val, zero, builder)
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
i_ptr.handle,
i_val.handle,
and_(mask, pos, builder).handle,
sem,
scope),
i_val.type)
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX,
i_ptr.handle,
i_val.handle,
and_(mask, neg, builder).handle,
sem,
scope),
i_val.type)
pos_ret = tl.tensor(
builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle,
and_(mask, pos, builder).handle, sem, scope), i_val.type)
neg_ret = tl.tensor(
builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, i_ptr.handle, i_val.handle,
and_(mask, neg, builder).handle, sem, scope), i_val.type)
ret = where(pos, pos_ret, neg_ret, builder)
return bitcast(ret, sca_ty, builder)
def atomic_add(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
scope: str,
builder: ir.builder) -> tl.tensor:
def atomic_add(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder)
sem = _str_to_sem(sem)
scope = _str_to_scope(scope)
@@ -1256,52 +1134,38 @@ def atomic_add(ptr: tl.tensor,
return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
def atomic_and(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
scope: str,
builder: ir.builder) -> tl.tensor:
def atomic_and(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder)
sem = _str_to_sem(sem)
scope = _str_to_scope(scope)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope),
val.type)
def atomic_or(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
scope: str,
builder: ir.builder) -> tl.tensor:
def atomic_or(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder)
sem = _str_to_sem(sem)
scope = _str_to_scope(scope)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope),
val.type)
def atomic_xor(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
scope: str,
builder: ir.builder) -> tl.tensor:
def atomic_xor(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder)
sem = _str_to_sem(sem)
scope = _str_to_scope(scope)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope),
val.type)
def atomic_xchg(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
scope: str,
def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder)
sem = _str_to_sem(sem)
scope = _str_to_scope(scope)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope),
val.type)
# ===----------------------------------------------------------------------===//
# Linear Algebra
@@ -1321,13 +1185,9 @@ def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
return True
def dot(lhs: tl.tensor,
rhs: tl.tensor,
acc: tl.tensor,
allow_tf32: bool,
max_num_imprecise_acc: int,
out_dtype: tl.dtype,
builder: ir.builder) -> tl.tensor:
def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_num_imprecise_acc: int,
out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
def assert_dtypes_valid(lhs_dtype, rhs_dtype, target):
# Checks for non-cuda archs
if not _is_cuda(target):
@@ -1335,22 +1195,30 @@ def dot(lhs: tl.tensor,
return
# Checks for cuda arch
if target.capability < 90:
assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(), "Dot op does not support fp8e4nv on CUDA arch < 90"
assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(
), "Dot op does not support fp8e4nv on CUDA arch < 90"
if lhs_dtype.is_fp8() and rhs_dtype.is_fp8():
return
assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
else:
assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(), "Dot op does not support fp8e4b15 on CUDA arch >= 90"
assert not lhs_dtype.is_fp8e4b15x4() and not rhs_dtype.is_fp8e4b15x4(), "Dot op does not support fp8e4b15x4 on CUDA arch >= 90"
assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(
), "Dot op does not support fp8e4b15 on CUDA arch >= 90"
assert not lhs_dtype.is_fp8e4b15x4() and not rhs_dtype.is_fp8e4b15x4(
), "Dot op does not support fp8e4b15x4 on CUDA arch >= 90"
if lhs_dtype.is_int() or rhs_dtype.is_int():
assert lhs_dtype == rhs_dtype, f"Both operands must be same type. First operand ({lhs_dtype}) and second operand ({rhs_dtype})"
assert lhs_dtype.is_int8() or lhs_dtype.is_uint8(), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})"
assert lhs_dtype.is_int8() or lhs_dtype.is_uint8(
), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})"
elif lhs_dtype.is_fp8() or rhs_dtype.is_fp8():
assert lhs_dtype.is_fp8e4nv() or lhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. First operand ({lhs_dtype})"
assert rhs_dtype.is_fp8e4nv() or rhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. Second operand ({rhs_dtype})"
assert lhs_dtype.is_fp8e4nv() or lhs_dtype.is_fp8e5(
), f"Only supports fp8e4nv or fp8e5. First operand ({lhs_dtype})"
assert rhs_dtype.is_fp8e4nv() or rhs_dtype.is_fp8e5(
), f"Only supports fp8e4nv or fp8e5. Second operand ({rhs_dtype})"
else:
assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1(), f"Unsupported dtype {lhs_dtype}"
assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1(), f"Unsupported dtype {rhs_dtype}"
assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1(
), f"Unsupported dtype {lhs_dtype}"
assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1(
), f"Unsupported dtype {rhs_dtype}"
assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
assert lhs.type.is_block() and rhs.type.is_block()
@@ -1359,7 +1227,8 @@ def dot(lhs: tl.tensor,
assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!"
assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!"
assert lhs.shape[1].value == rhs.shape[0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})"
assert lhs.shape[1].value == rhs.shape[
0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})"
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
and rhs.shape[1].value >= 16, \
f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!"
@@ -1370,7 +1239,8 @@ def dot(lhs: tl.tensor,
_0 = builder.get_int32(0)
ret_scalar_ty = tl.int32
elif out_dtype.is_bf16():
raise ValueError("out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`")
raise ValueError(
"out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`")
elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16():
_0 = builder.get_fp32(0)
ret_scalar_ty = tl.float32
@@ -1391,10 +1261,10 @@ def dot(lhs: tl.tensor,
else:
_0 = builder.create_splat(builder.get_fp32(0), [M, N])
ret_ty = tl.block_type(ret_cast_scalar_ty, [M, N])
ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
ret_ty)
ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), ret_ty)
return cast(ret, ret_scalar_ty, builder)
if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty) and ret_scalar_ty.primitive_bitwidth < 32:
if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32,
ret_scalar_ty) and ret_scalar_ty.primitive_bitwidth < 32:
if lhs.type.scalar.is_int():
ret_dot_scalar_ty = tl.int32
_0 = builder.create_splat(builder.get_int32(0), [M, N])
@@ -1402,8 +1272,7 @@ def dot(lhs: tl.tensor,
ret_dot_scalar_ty = tl.float32
_0 = builder.create_splat(builder.get_fp32(0), [M, N])
ret_ty = tl.block_type(ret_dot_scalar_ty, [M, N])
ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
ret_ty)
ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), ret_ty)
return cast(ret, ret_scalar_ty, builder)
ret_ty = tl.block_type(ret_scalar_ty, [M, N])
if acc is None:
@@ -1413,23 +1282,21 @@ def dot(lhs: tl.tensor,
assert acc.type == ret_ty
# max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
if not (_is_cuda(builder.target) and builder.target.capability == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and ret_scalar_ty.is_fp32()):
if not (_is_cuda(builder.target) and builder.target.capability == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8()
and ret_scalar_ty.is_fp32()):
max_num_imprecise_acc = 0
if max_num_imprecise_acc is None:
max_num_imprecise_acc = 2**30
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc),
ret_ty)
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc), ret_ty)
# ===----------------------------------------------------------------------===//
# Indexing
# ===----------------------------------------------------------------------===//
def where(condition: tl.tensor,
x: tl.tensor,
y: tl.tensor,
builder: ir.builder) -> tl.tensor:
def where(condition: tl.tensor, x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
condition = cast(condition, tl.int1, builder)
if condition.type.is_block():
condition, x = broadcast_impl_value(condition, x, builder)
@@ -1442,14 +1309,13 @@ def where(condition: tl.tensor,
ret_ty = x.type
return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
# ===----------------------------------------------------------------------===//
# Reduction
# ===----------------------------------------------------------------------===
def reduction(
inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder
) -> Tuple[tl.tensor, ...]:
def reduction(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder) -> Tuple[tl.tensor, ...]:
if axis is None:
new_inputs = []
for i in range(len(inputs)):
@@ -1475,10 +1341,7 @@ def reduction(
region_builder_fn(reduce_op)
reduce_op.verify()
return tuple(
wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar)
for i in range(len(inputs))
)
return tuple(wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar) for i in range(len(inputs)))
# ===----------------------------------------------------------------------===
@@ -1486,9 +1349,8 @@ def reduction(
# ===----------------------------------------------------------------------===
def associative_scan(
inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder
) -> Tuple[tl.tensor, ...]:
def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn,
builder: ir.builder) -> Tuple[tl.tensor, ...]:
if len(inputs) != 1:
raise ValueError("Current implementation only support single tensor input")
shape = inputs[0].type.shape
@@ -1501,16 +1363,14 @@ def associative_scan(
region_builder_fn(scan_op)
scan_op.verify()
return tuple(
wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar)
for i in range(len(inputs))
)
return tuple(wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar) for i in range(len(inputs)))
# ===----------------------------------------------------------------------===
# Math
# ===----------------------------------------------------------------------===
def _check_dtype(dtypes: List[str]) -> T:
"""
We're following libdevice's convention to check accepted data types for math functions.
@@ -1519,7 +1379,9 @@ def _check_dtype(dtypes: List[str]) -> T:
We should let the users know that they are using and invoke explicit cast to convert
the data type to the supported one.
"""
def wrapper(fn):
@wraps(fn)
def check(*args, **kwargs):
# concatenate args and kwargs
@@ -1528,6 +1390,7 @@ def _check_dtype(dtypes: List[str]) -> T:
if arg.type.scalar.name not in dtypes:
raise ValueError(f"Expected dtype {dtypes} but got {arg.type.scalar.name}")
return fn(*args, **kwargs)
return check
return wrapper
@@ -1631,8 +1494,8 @@ def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.
def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor:
cond_ty = cond.type
if not cond_ty.is_block():
cond_ty = tl.block_type(cond_ty.scalar, (1,))
cond = tl.tensor(builder.create_splat(cond.handle, (1,)), cond_ty)
cond_ty = tl.block_type(cond_ty.scalar, (1, ))
cond = tl.tensor(builder.create_splat(cond.handle, (1, )), cond_ty)
return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void)

View File

@@ -123,6 +123,7 @@ def maximum(x, y):
"""
return math.max(x, y)
# max and argmax
@@ -149,8 +150,7 @@ def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
@jit
@core._add_reduction_docstr("maximum",
return_indices_arg="return_indices",
@core._add_reduction_docstr("maximum", return_indices_arg="return_indices",
tie_break_arg="return_indices_tie_break_left")
def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
input = core._promote_reduction_input(input)
@@ -175,6 +175,7 @@ def argmax(input, axis, tie_break_left=True):
(_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
return ret
# min and argmin
@@ -201,8 +202,7 @@ def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
@jit
@core._add_reduction_docstr("minimum",
return_indices_arg="return_indices",
@core._add_reduction_docstr("minimum", return_indices_arg="return_indices",
tie_break_arg="return_indices_tie_break_left")
def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
input = core._promote_reduction_input(input)
@@ -222,8 +222,7 @@ def min(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
@jit
@core._add_reduction_docstr("minimum index",
tie_break_arg="tie_break_left")
@core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left")
def argmin(input, axis, tie_break_left=True):
_, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
return ret
@@ -233,6 +232,7 @@ def argmin(input, axis, tie_break_left=True):
def _sum_combine(a, b):
return a + b
# sum
@@ -247,6 +247,7 @@ def sum(input, axis=None):
def _xor_combine(a, b):
return a ^ b
# xor sum
@@ -258,8 +259,8 @@ def xor_sum(input, axis=None, _builder=None, _generator=None):
raise ValueError("xor_sum only supported for integers")
input = core._promote_reduction_input(input, _builder=_builder)
return core.reduce(input, axis, _xor_combine,
_builder=_builder, _generator=_generator)
return core.reduce(input, axis, _xor_combine, _builder=_builder, _generator=_generator)
# cumsum
@@ -271,6 +272,7 @@ def cumsum(input, axis=0):
input = core._promote_reduction_input(input)
return core.associative_scan(input, axis, _sum_combine)
# cumprod

View File

@@ -17,15 +17,14 @@ from ... import language as tl
'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
})
@jit
def _sdd_kernel(
A, B, C,
stride_za, stride_ha, stride_ma, stride_ak,
stride_zb, stride_hb, stride_bk, stride_nb,
stride_zc, stride_hc, stride_mc, stride_nc,
K, grid_offset, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
BLOCK: tl.constexpr, EVEN_K: tl.constexpr
):
def _sdd_kernel(A, B, C, #
stride_za, stride_ha, stride_ma, stride_ak, #
stride_zb, stride_hb, stride_bk, stride_nb, #
stride_zc, stride_hc, stride_mc, stride_nc, #
K, grid_offset, lut, #
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, #
BLOCK: tl.constexpr, EVEN_K: tl.constexpr #
):
# ------------ #
# - Prologue - #
# ------------ #
@@ -104,13 +103,13 @@ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=
c = out
grid = [c.shape[1], 1, c.shape[0]]
_sdd_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
Ka, 0, lut,
TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4,
num_warps=4,
a, b, c, #
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), #
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), #
c.stride(0), c.stride(1), c.stride(2), c.stride(3), #
Ka, 0, lut, #
TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4, #
num_warps=4 #
)
return c
@@ -120,6 +119,7 @@ def sdd_lut(layout, block, device):
lut = lut.contiguous()
return lut, None
# -----------------------------
# Dense = Sparse x Dense (DSD)
# This operation uses a look-up table that contains pre-computed pointer increments
@@ -128,15 +128,14 @@ def sdd_lut(layout, block, device):
@jit
def _dsd_kernel(
A, B, C,
stride_az, stride_ha, stride_am, stride_ak,
stride_zb, stride_hb, stride_bk, stride_bn,
stride_zc, stride_hc, stride_cm, stride_cn,
DS0, DS1, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
):
def _dsd_kernel(A, B, C, #
stride_az, stride_ha, stride_am, stride_ak, #
stride_zb, stride_hb, stride_bk, stride_bn, #
stride_zc, stride_hc, stride_cm, stride_cn, #
DS0, DS1, lut, #
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr #
):
# ------------ #
# - Prologue - #
# ------------ #
@@ -229,13 +228,13 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N
# compute output
grid = lambda meta: [cdiv(BS3, meta['TILE_N']), width, BS0]
_dsd_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
BS3, AS1, lut,
TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
num_warps=4, GROUP_SIZE_M=4,
a, b, c, #
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), #
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), #
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), #
BS3, AS1, lut, #
TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4, #
num_warps=4, GROUP_SIZE_M=4 #
)
# exit()
return c
@@ -337,6 +336,7 @@ def dsd_lut(layout, block, step, trans, device):
# create locks
return lut, width
# -----------------------------
# Dense = Dense x Sparse (DDS)
# -----------------------------
@@ -346,6 +346,7 @@ def dsd_lut(layout, block, step, trans, device):
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out)
##############
# MAIN API #
##############
@@ -356,10 +357,8 @@ class _matmul(torch.autograd.Function):
fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul}
@staticmethod
def forward(
ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
c_lut, c_width, da_lut, da_width, db_lut, db_width, out
):
def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_width, da_lut, da_width, db_lut,
db_width, out):
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out)
# save for backward
ctx.save_for_backward(a, b)
@@ -385,15 +384,13 @@ class _matmul(torch.autograd.Function):
# gradients w.r.t. a
if ctx.needs_input_grad[0]:
mode_da = mode[1] + mode[0] + mode[2]
da = _matmul.fn[mode_da](
dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width,
)
da = _matmul.fn[mode_da](dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block,
ctx.da_lut, ctx.da_width)
# gradients w.r.t. b
if ctx.needs_input_grad[1]:
mode_db = mode[2] + mode[1] + mode[0]
db = _matmul.fn[mode_db](
a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width,
)
db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block,
ctx.db_lut, ctx.db_width)
dout = dc if ctx.has_out else None
return da, db, None, None, None, \
None, None, None, None, \
@@ -427,11 +424,9 @@ class matmul:
self.db_lut, self.db_width = sdd_lut(layout, block, device)
def __call__(self, a, b, out=None):
c = _matmul.apply(
a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
self.c_lut, self.c_width,
self.da_lut, self.da_width,
self.db_lut, self.db_width,
out
)
c = _matmul.apply(a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, #
self.c_lut, self.c_width, #
self.da_lut, self.da_width, #
self.db_lut, self.db_width, #
out)
return c

View File

@@ -18,14 +18,13 @@ def num_warps(n):
@jit
def _blocksparse_softmax_fwd(
Out, A, stride_xz, LUT,
R, extent, stride_zr, stride_hr, # relative attention
scale, is_causal,
ROW_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
IS_DENSE: tl.constexpr,
):
def _blocksparse_softmax_fwd(Out, A, stride_xz, LUT, #
R, extent, stride_zr, stride_hr, # relative attention
scale, is_causal, #
ROW_SIZE: tl.constexpr, #
BLOCK_SIZE: tl.constexpr, #
IS_DENSE: tl.constexpr #
):
h = tl.program_id(0)
m = tl.program_id(1)
z = tl.program_id(2)
@@ -73,18 +72,16 @@ def _blocksparse_softmax_fwd(
@jit
def _blocksparse_softmax_bwd(
DA, stride_zdx,
DOut, stride_zdout,
Out, stride_zout,
scale,
LUT,
DR, extent, stride_zr, stride_hr, stride_er,
is_causal,
ROW_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
IS_DENSE: tl.constexpr,
):
def _blocksparse_softmax_bwd(DA, stride_zdx, #
DOut, stride_zdout, #
Out, stride_zout, #
scale, #
LUT, #
DR, extent, stride_zr, stride_hr, stride_er, #
is_causal, #
ROW_SIZE: tl.constexpr, #
BLOCK_SIZE: tl.constexpr, #
IS_DENSE: tl.constexpr):
h = tl.program_id(0)
m = tl.program_id(1)
z = tl.program_id(2)
@@ -133,6 +130,7 @@ def _blocksparse_softmax_bwd(
class _softmax(torch.autograd.Function):
@staticmethod
def make_lut(layout, block, device):
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
@@ -151,10 +149,7 @@ class _softmax(torch.autograd.Function):
return lut, int(total_sizes.max())
@staticmethod
def forward(
ctx, a, scale, rel_logits, is_causal,
spdims, block, lut, maxlut, is_dense
):
def forward(ctx, a, scale, rel_logits, is_causal, spdims, block, lut, maxlut, is_dense):
if scale is not None and isinstance(scale, torch.Tensor):
assert scale.device.type == "cpu"
scale = scale.item()
@@ -165,14 +160,14 @@ class _softmax(torch.autograd.Function):
# enqueue kernel
out = torch.empty_like(a)
_blocksparse_softmax_fwd[grid](
out, a, a.stride(0), lut,
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
scale,
is_causal,
BLOCK_SIZE=block,
ROW_SIZE=next_power_of_2(maxlut),
IS_DENSE=is_dense,
num_warps=num_warps(maxlut)
out, a, a.stride(0), lut, #
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn#
scale, #
is_causal, #
BLOCK_SIZE=block, #
ROW_SIZE=next_power_of_2(maxlut), #
IS_DENSE=is_dense, #
num_warps=num_warps(maxlut) #
)
# save to context
# ctx.mark_dirty(x)
@@ -201,28 +196,23 @@ class _softmax(torch.autograd.Function):
grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)
da = torch.empty_like(dout)
_blocksparse_softmax_bwd[grid](
da, da.stride(0),
dout, dout.stride(0),
out, out.stride(0),
ctx.scale,
lut,
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],
ctx.is_causal,
BLOCK_SIZE=ctx.block,
ROW_SIZE=next_power_of_2(ctx.maxlut),
IS_DENSE=ctx.is_dense,
num_warps=num_warps(ctx.maxlut)
da, da.stride(0), #
dout, dout.stride(0), #
out, out.stride(0), #
ctx.scale, #
lut, #
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2], #
ctx.is_causal, #
BLOCK_SIZE=ctx.block, #
ROW_SIZE=next_power_of_2(ctx.maxlut), #
IS_DENSE=ctx.is_dense, #
num_warps=num_warps(ctx.maxlut) #
)
return (da, None, None, dr, None,
None, None, None, None, None,
None,
None, None, None,
None,
None, None, None
)
return (da, None, None, dr, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
class softmax:
def __init__(self, layout, block, device, is_dense=False):
self.spdims = layout.shape
self.layout = layout
@@ -233,8 +223,6 @@ class softmax:
def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False):
if rel_logits is not None and rel_logits.dtype != a.dtype:
raise ValueError(f"relative position embedding must be {a.dtype}")
a = _softmax.apply(
a, scale, rel_logits, is_causal,
self.spdims, self.block, self.lut, self.maxlut, self.is_dense,
)
a = _softmax.apply(a, scale, rel_logits, is_causal, self.spdims, self.block, self.lut, self.maxlut,
self.is_dense)
return a

View File

@@ -59,6 +59,7 @@ def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
class _cross_entropy(torch.autograd.Function):
@classmethod
def forward(cls, ctx, logits, indices):
# make sure we can use triton

View File

@@ -15,22 +15,19 @@ from .. import language as tl
@jit
def _fwd_kernel(
# fmt: off
Q, K, V, sm_scale,
L,
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX,
Z_H_N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
IS_CAUSAL: tl.constexpr,
# fmt: on
):
def _fwd_kernel(Q, K, V, sm_scale, #
L, #
Out, #
stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vn, stride_vk, #
stride_oz, stride_oh, stride_om, stride_on, #
Z, H, N_CTX, #
Z_H_N_CTX, #
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
BLOCK_N: tl.constexpr, #
IS_CAUSAL: tl.constexpr #
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
qvk_offset = off_hz * stride_qh
@@ -132,27 +129,24 @@ def _bwd_preprocess(
@jit
def _bwd_kernel_one_col_block(
# fmt: off
Q, K, V, sm_scale, qk_scale,
Out, DO,
DQ, DK, DV,
L,
D,
Q_block_ptr, K_block_ptr, V_block_ptr,
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
Z, H, N_CTX,
off_h, off_z, off_hz, start_n, num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
CAUSAL: tl.constexpr,
MMA_V3: tl.constexpr
# fmt: on
):
def _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, #
Out, DO, #
DQ, DK, DV, #
L, #
D, #
Q_block_ptr, K_block_ptr, V_block_ptr, #
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vn, stride_vk, #
Z, H, N_CTX, #
off_h, off_z, off_hz, start_n, num_block, #
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
BLOCK_N: tl.constexpr, #
SEQUENCE_PARALLEL: tl.constexpr, #
CAUSAL: tl.constexpr, #
MMA_V3: tl.constexpr #
):
if CAUSAL:
lo = start_n * BLOCK_M
else:
@@ -235,26 +229,23 @@ def _bwd_kernel_one_col_block(
@jit
def _bwd_kernel(
# fmt: off
Q, K, V, sm_scale,
Out, DO,
DQ, DK, DV,
L,
D,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
Z, H, N_CTX,
Z_H_N_CTX,
SQ_Z_H_N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
CAUSAL: tl.constexpr,
MMA_V3: tl.constexpr
# fmt: on
):
def _bwd_kernel(Q, K, V, sm_scale, #
Out, DO, #
DQ, DK, DV, #
L, #
D, #
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vn, stride_vk, #
Z, H, N_CTX, #
Z_H_N_CTX, #
SQ_Z_H_N_CTX, #
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
BLOCK_N: tl.constexpr, #
SEQUENCE_PARALLEL: tl.constexpr, #
CAUSAL: tl.constexpr, #
MMA_V3: tl.constexpr #
):
qk_scale = sm_scale * 1.44269504
off_hz = tl.program_id(0)
off_z = off_hz // H
@@ -331,51 +322,46 @@ def _bwd_kernel(
num_block_n = tl.cdiv(N_CTX, BLOCK_N)
if not SEQUENCE_PARALLEL:
for start_n in range(0, num_block_n):
_bwd_kernel_one_col_block(
# fmt: off
Q, K, V, sm_scale, qk_scale, Out, DO,
DQ, DK, DV,
L,
D,
Q_block_ptr, K_block_ptr, V_block_ptr,
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
Z, H, N_CTX,
off_h, off_z, off_hz, start_n, num_block_n,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
MMA_V3=MMA_V3
# fmt: on
)
_bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, #
DQ, DK, DV, #
L, #
D, #
Q_block_ptr, K_block_ptr, V_block_ptr, #
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vn, stride_vk, #
Z, H, N_CTX, #
off_h, off_z, off_hz, start_n, num_block_n, #
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, #
BLOCK_N=BLOCK_N, #
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, #
CAUSAL=CAUSAL, #
MMA_V3=MMA_V3 #
)
else:
start_n = tl.program_id(1)
_bwd_kernel_one_col_block(
# fmt: off
Q, K, V, sm_scale, qk_scale, Out, DO,
DQ, DK, DV,
L,
D,
Q_block_ptr, K_block_ptr, V_block_ptr,
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
Z, H, N_CTX,
off_h, off_z, off_hz, start_n, num_block_n,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
MMA_V3=MMA_V3
# fmt: on
)
_bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, #
DQ, DK, DV, #
L, #
D, #
Q_block_ptr, K_block_ptr, V_block_ptr, #
DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vn, stride_vk, #
Z, H, N_CTX, #
off_h, off_z, off_hz, start_n, num_block_n, #
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, #
BLOCK_N=BLOCK_N, #
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, #
CAUSAL=CAUSAL, #
MMA_V3=MMA_V3 #
)
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False):
# only support for Ampere now
@@ -393,21 +379,19 @@ class _attention(torch.autograd.Function):
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
_fwd_kernel[grid](
# fmt: off
q, k, v, sm_scale,
L,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
q.shape[0] * q.shape[1] * q.shape[2],
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
IS_CAUSAL=causal,
num_warps=num_warps,
num_stages=4,
# fmt: on
q, k, v, sm_scale, #
L, #
o, #
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
q.shape[0], q.shape[1], q.shape[2], #
q.shape[0] * q.shape[1] * q.shape[2], #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, #
IS_CAUSAL=causal, #
num_warps=num_warps, #
num_stages=4 #
)
ctx.save_for_backward(q, k, v, o, L)
@@ -429,14 +413,14 @@ class _attention(torch.autograd.Function):
do = do.contiguous()
if sequence_parallel:
replicas = cdiv(seq_len_kv, BLOCK)
new_dq_shape = (replicas,) + q.shape
new_dq_shape = (replicas, ) + q.shape
dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype)
else:
dq = torch.zeros_like(q, dtype=q.dtype)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
delta = torch.empty_like(L)
_bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1],)](
_bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1], )](
o,
do,
delta,
@@ -444,26 +428,24 @@ class _attention(torch.autograd.Function):
D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)](
# fmt: off
q, k, v, ctx.sm_scale,
o, do,
dq, dk, dv,
L,
delta,
o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
q.shape[0] * q.shape[1] * q.shape[2],
cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
SEQUENCE_PARALLEL=sequence_parallel,
CAUSAL=ctx.causal,
MMA_V3=MMA_V3,
num_warps=8,
num_stages=1,
# fmt: on
q, k, v, ctx.sm_scale, #
o, do, #
dq, dk, dv, #
L, #
delta, #
o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
q.shape[0], q.shape[1], q.shape[2], #
q.shape[0] * q.shape[1] * q.shape[2], #
cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2], #
BLOCK_M=BLOCK, BLOCK_N=BLOCK, #
BLOCK_DMODEL=ctx.BLOCK_DMODEL, #
SEQUENCE_PARALLEL=sequence_parallel, #
CAUSAL=ctx.causal, #
MMA_V3=MMA_V3, #
num_warps=8, #
num_stages=1 #
)
if len(dq.shape) == 5:

View File

@@ -37,8 +37,9 @@ def get_configs_io_bound():
num_stages=num_stages, num_warps=num_warps))
# split_k
for split_k in [2, 4, 8, 16]:
configs.append(Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
configs.append(
Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs
@@ -69,22 +70,22 @@ def get_configs_io_bound():
prune_configs_by={
'early_config_prune': early_config_prune,
'perf_model': estimate_matmul_time,
'top_k': 10
'top_k': 10,
},
)
@heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@jit
def _kernel(A, B, C, M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
dot_out_dtype: tl.constexpr,
allow_tf32: tl.constexpr,
fp8_fast_accum: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr
def _kernel(A, B, C, M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
dot_out_dtype: tl.constexpr, #
allow_tf32: tl.constexpr, #
fp8_fast_accum: tl.constexpr, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr #
):
# matrix multiplication
pid = tl.program_id(0)
@@ -184,14 +185,15 @@ class _matmul(torch.autograd.Function):
ab_dtype = False
# launch kernel
grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
dot_out_dtype=dot_out_dtype,
allow_tf32=allow_tf32,
fp8_fast_accum=fp8_fast_accum,
GROUP_M=8, AB_DTYPE=ab_dtype)
_kernel[grid](
a, b, c, M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
dot_out_dtype=dot_out_dtype, #
allow_tf32=allow_tf32, #
fp8_fast_accum=fp8_fast_accum, #
GROUP_M=8, AB_DTYPE=ab_dtype)
return c
@staticmethod

View File

@@ -5,8 +5,7 @@ import torch
from .. import cdiv
from .._C.libtriton.triton import runtime
from ..runtime import driver
from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops,
nvsmi)
from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops, nvsmi)
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
@@ -14,7 +13,8 @@ def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
total_warps = num_ctas * min(num_warps, 4)
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, cur_sm_clock, backend, device)
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(
dtype, cur_sm_clock, backend, device)
return tflops
@@ -35,12 +35,12 @@ def get_tflops(backend, device, num_ctas, num_warps, dtype):
def estimate_matmul_time(
# backend, device,
num_warps, num_stages,
A, B, C,
M, N, K,
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
debug=False, **kwargs
# backend, device,
num_warps, num_stages, #
A, B, C, #
M, N, K, #
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, #
debug=False, **kwargs #
):
''' return estimated running time in ms
= max(compute, loading) + store '''
@@ -149,8 +149,9 @@ def early_config_prune(configs, named_args):
optimal_num_stages = ldgsts_latency / mma_cycles
# nearest stages, prefer large #stages
nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
nearest = heapq.nsmallest(
2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
for n in nearest:
pruned_configs.append(n[0])

View File

@@ -1,5 +1,4 @@
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune,
heuristics)
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune, heuristics)
from .driver import driver
from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret

View File

@@ -9,11 +9,10 @@ from .jit import KernelInterface
class OutOfResources(Exception):
def __init__(self, required, limit, name):
self.message = (
f"out of resource: {name}, Required: {required}, Hardware limit: {limit}. "
+ "Reducing block sizes or `num_stages` may help."
)
self.message = (f"out of resource: {name}, Required: {required}, Hardware limit: {limit}. " +
"Reducing block sizes or `num_stages` may help.")
self.required = required
self.limit = limit
self.name = name
@@ -25,6 +24,7 @@ class OutOfResources(Exception):
class Autotuner(KernelInterface):
def __init__(
self,
fn,
@@ -99,10 +99,8 @@ class Autotuner(KernelInterface):
# as kwargs and by the autotuner
conflicts = meta.keys() & config.kwargs.keys()
if conflicts:
raise ValueError(
f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols."
)
raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols.")
# augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs)
full_nargs = {**self.nargs, **current}
@@ -179,7 +177,8 @@ class Autotuner(KernelInterface):
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {
config: self.perf_model(
config:
self.perf_model(
**self.nargs,
**kwargs,
**config.kwargs,
@@ -296,6 +295,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_va
class Heuristics(KernelInterface):
def __init__(self, fn, arg_names, values) -> None:
self.fn = fn
self.values = values

View File

@@ -19,6 +19,7 @@ def default_dump_dir():
class CacheManager(ABC):
def __init__(self, key):
pass
@@ -44,6 +45,7 @@ class CacheManager(ABC):
class FileCacheManager(CacheManager):
def __init__(self, key, override=False, dump=False):
self.key = key
self.lock_path = None

View File

@@ -26,6 +26,7 @@ class DriverBase(metaclass=abc.ABCMeta):
class CudaUtils(object):
def __new__(cls):
if not hasattr(cls, "instance"):
cls.instance = super(CudaUtils, cls).__new__(cls)
@@ -65,6 +66,7 @@ class CudaUtils(object):
class CudaDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, "instance"):
cls.instance = super(CudaDriver, cls).__new__(cls)
@@ -81,6 +83,7 @@ class CudaDriver(DriverBase):
class HIPUtils(object):
def __new__(cls):
if not hasattr(cls, "instance"):
cls.instance = super(HIPUtils, cls).__new__(cls)
@@ -111,6 +114,7 @@ class HIPUtils(object):
class HIPDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, "instance"):
cls.instance = super(HIPDriver, cls).__new__(cls)
@@ -122,6 +126,7 @@ class HIPDriver(DriverBase):
class UnsupportedDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, "instance"):
cls.instance = super(UnsupportedDriver, cls).__new__(cls)
@@ -138,6 +143,7 @@ class UnsupportedDriver(DriverBase):
class LazyProxy:
def __init__(self, init_fn):
self._init_fn = init_fn
self._obj = None

View File

@@ -1,4 +1,5 @@
class OutOfResources(Exception):
def __init__(self, required, limit, name):
self.message = f"out of resource: {name}, " f"Required: {required}, " f"Hardware limit: {limit}"
self.message += ". Reducing block sizes or `num_stages` may help."

View File

@@ -37,6 +37,7 @@ def str_to_ty(name):
class TensorHandle:
def __init__(self, data, dtype):
self.data = data
self.dtype = dtype
@@ -46,6 +47,7 @@ class TensorHandle:
class BlockPointerHandle:
def __init__(self, base, shape, strides, offsets, tensor_shape, order):
self.base = base
self.shape = shape
@@ -72,7 +74,9 @@ class BlockPointerHandle:
def wrap_ret(compute_ret_ty):
def wrapper(fn):
def wrapped(*args, **kwargs):
ret = fn(*args, **kwargs)
return TensorHandle(ret.data, compute_ret_ty(*args, **kwargs))
@@ -83,6 +87,7 @@ def wrap_ret(compute_ret_ty):
class Builder:
def __init__(self) -> None:
self.arch = None
# pass
@@ -280,9 +285,8 @@ class Builder:
dtype_tt = ptr.dtype.element_ty
return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype)
def create_tensor_pointer_load(
self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile
):
def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy,
is_volatile):
ptrs, masks = ptr.materialize_pointers(boundary_check)
assert padding_option is None
other = None
@@ -364,9 +368,10 @@ class Builder:
def patch_attr(obj, name, member, builder):
new_member = lambda *args, member=member, **kwargs: (
member(*args, **{k: v for k, v in kwargs.items() if k != "_builder"}, _builder=builder)
)
new_member = lambda *args, member=member, **kwargs: (member(*args, **
{k: v
for k, v in kwargs.items()
if k != "_builder"}, _builder=builder))
setattr(obj, name, new_member)
@@ -412,6 +417,7 @@ def _patch_lang_math(lang, builder):
}
def make_numpy(name):
def impl(*args, **kwargs):
ret_type = args[0].type # TODO: incorrect
ret_dtype = args[0].dtype # TODO: incorrect
@@ -424,14 +430,13 @@ def _patch_lang_math(lang, builder):
return impl
def make_fallback(name):
def fallback(*args, **kwargs):
raise NotImplementedError(
f"""
raise NotImplementedError(f"""
{name} not supported in interpreter mode: no known numpy implementation.
If you think that {name} in fact does have a numpy implementation, please add it
to the mapping in python/triton/interpreter/new_interpreter.py:_patch_lang_math.
"""
)
""")
return fallback
@@ -467,6 +472,7 @@ RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_warp_specializati
class GridExecutor:
def __init__(self, fn, arg_names, grid):
from .jit import _normalize_ty # TODO: modularize
@@ -496,7 +502,7 @@ class GridExecutor:
# iterate through grid
grid = self.grid(args) if callable(self.grid) else self.grid
assert len(grid) <= 3
grid = grid + (1,) * (3 - len(grid))
grid = grid + (1, ) * (3 - len(grid))
builder.set_grid_dim(*grid)
for x in range(grid[0]):
for y in range(grid[1]):
@@ -510,6 +516,7 @@ class GridExecutor:
class InterpretedFunction:
def _patch_lang(self, builder):
lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
assert len(lang) == 1, "triton.language must be visible from within jit'd function"

View File

@@ -72,9 +72,8 @@ class DependenciesFinder(ast.NodeVisitor):
lhs = self.visit(node.value)
while isinstance(lhs, ast.Attribute):
lhs = self.visit(lhs.value)
if lhs is None or (
getattr(lhs, "__name__", "") == "triton" or getattr(lhs, "__name__", "").endswith(".triton")
):
if lhs is None or (getattr(lhs, "__name__", "") == "triton"
or getattr(lhs, "__name__", "").endswith(".triton")):
return None
return getattr(lhs, node.attr)
@@ -176,7 +175,7 @@ class KernelArg:
assert not self.param.do_not_specialize
try:
return (self.value.data_ptr() % JITFunction.divisibility == 0,)
return (self.value.data_ptr() % JITFunction.divisibility == 0, )
except AttributeError:
pass
@@ -188,7 +187,7 @@ class KernelArg:
self.value == 1,
)
return (False,)
return (False, )
class KernelInterface(Generic[T]):
@@ -253,10 +252,11 @@ class JITFunction(KernelInterface[T]):
return arg.data_ptr() % JITFunction.divisibility == 0
elif isinstance(arg, int):
return (arg % 16 == 0, arg == 1)
return (arg is None,)
return (arg is None, )
# TODO(jlebar): Fold this into the KernelArg class.
def _get_config(self, *args):
def is_divisible_by_16(x):
if hasattr(x, "data_ptr"):
return x.data_ptr() % JITFunction.divisibility == 0
@@ -279,7 +279,9 @@ class JITFunction(KernelInterface[T]):
if is_divisible_by_16(arg) and not param.do_not_specialize
}
divisible_by_8 = {
param.num for param, arg in zip(self.params, args) if is_divisible_by_8(arg) and not param.do_not_specialize
param.num
for param, arg in zip(self.params, args)
if is_divisible_by_8(arg) and not param.do_not_specialize
}
equal_to_1 = {
param.num
@@ -290,9 +292,10 @@ class JITFunction(KernelInterface[T]):
# TODO: method to collect all folded args
none_args = {param.num for param, arg in zip(self.params, args) if arg is None and not param.do_not_specialize}
ids_of_folded_args = equal_to_1 | none_args
return namedtuple(
"instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"]
)(tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), tuple(divisible_by_8))
return namedtuple("instance_descriptor",
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])( #
tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args),
tuple(divisible_by_8))
# return _triton.code_gen.instance_descriptor(divisible_by_16,
# equal_to_1)
@@ -356,6 +359,7 @@ class JITFunction(KernelInterface[T]):
key = str(key)
class LegacyCompiler:
def __init__(self, module, name):
self.module = module
self.name = name
@@ -449,9 +453,8 @@ class JITFunction(KernelInterface[T]):
if device_type is None:
device_types = [self._device_of(arg) for arg in non_constexpr_arg_values]
device_types = [_device_type for _device_type in device_types if _device_type != ""]
device_type = self._conclude_device_type(
device_types, [self._pinned_memory_of(arg) for arg in non_constexpr_arg_values]
)
device_type = self._conclude_device_type(device_types,
[self._pinned_memory_of(arg) for arg in non_constexpr_arg_values])
device_backend = None
if device_type not in ["cuda"]:
@@ -498,7 +501,7 @@ class JITFunction(KernelInterface[T]):
# Kernel is not cached; we have to compile.
if key not in self.cache[device]:
configs = (self._get_config(*[arg.value for arg in args]),)
configs = (self._get_config(*[arg.value for arg in args]), )
constants = {
arg.param.num: arg.value
for arg in args
@@ -510,21 +513,23 @@ class JITFunction(KernelInterface[T]):
# Build kernel signature -- doesn't include constexpr arguments.
signature = {
arg.param.num: self._type_of(self._key_of(arg.value)) for arg in args if not arg.param.is_constexpr
arg.param.num: self._type_of(self._key_of(arg.value))
for arg in args
if not arg.param.is_constexpr
}
if self._call_hook(
key,
signature,
device,
constants,
num_warps,
num_ctas,
num_stages,
enable_warp_specialization,
enable_fp_fusion,
extern_libs,
configs,
key,
signature,
device,
constants,
num_warps,
num_ctas,
num_stages,
enable_warp_specialization,
enable_fp_fusion,
extern_libs,
configs,
):
return None
@@ -581,7 +586,7 @@ class JITFunction(KernelInterface[T]):
# function source code (without decorators)
self.src = textwrap.dedent(inspect.getsource(fn))
self.src = self.src[self.src.find("def") :]
self.src = self.src[self.src.find("def"):]
# cache of just-in-time compiled kernels
self.cache = defaultdict(dict)
self.hash = None
@@ -734,6 +739,7 @@ class MockTensor:
class TensorWrapper:
def __init__(self, base, dtype):
self.dtype = dtype
self.base = base

View File

@@ -78,10 +78,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None):
return torch.mean(torch.tensor(ret)).item()
def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
quantiles=None,
fast_flush=True,
return_mode="mean"):
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"):
assert return_mode in ["min", "max", "mean", "median"]
import torch
"""
@@ -261,6 +258,7 @@ class Benchmark:
class Mark:
def __init__(self, fn, benchmarks):
self.fn = fn
self.benchmarks = benchmarks
@@ -405,12 +403,15 @@ def get_max_tensorcore_tflops(dtype, clock_rate, backend=None, device=None):
tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
return tflops
# create decorator that wraps test function into
# a cuda-memcheck system call
def cuda_memcheck(**target_kwargs):
def decorator(test_fn):
@functools.wraps(test_fn)
def wrapper(*args, **kwargs):
import psutil
@@ -428,7 +429,9 @@ def cuda_memcheck(**target_kwargs):
assert "ERROR SUMMARY: 0 errors" in str(out.stdout)
else:
test_fn(*args, **kwargs)
return wrapper
return decorator
@@ -436,22 +439,18 @@ def cuda_memcheck(**target_kwargs):
def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
try:
subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"])
subprocess.check_output(
[
"nvidia-smi",
"-i",
"0",
f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
]
)
subprocess.check_output(
[
"nvidia-smi",
"-i",
"0",
f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
]
)
subprocess.check_output([
"nvidia-smi",
"-i",
"0",
f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
])
subprocess.check_output([
"nvidia-smi",
"-i",
"0",
f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
])
cur_sm_clock = nvsmi(["clocks.current.sm"])[0]
cur_mem_clock = nvsmi(["clocks.current.memory"])[0]
assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz"

View File

@@ -141,8 +141,7 @@ class ExternLibrary(ABC):
f.write(file_str)
f.close()
if self._format:
subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file],
stdout=subprocess.PIPE).communicate()
subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], stdout=subprocess.PIPE).communicate()
subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate()
@@ -208,56 +207,36 @@ class Libdevice(ExternLibrary):
# Group functions together by renaming.
renaming = {
'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh',
'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': 'add_rn',
'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru',
'dadd_rz': 'add_rz', 'fadd_rz': 'add_rz', 'asinf': 'asin',
'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2',
'atanhf': 'atanh', 'brevll': 'brev', 'cbrtf': 'cbrt',
'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign',
'cosf': 'cos', 'coshf': 'cosh', 'cospif': 'cospi',
'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn',
'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', 'ddiv_ru': 'div_ru',
'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf',
'erfcf': 'erfc', 'erfcinvf': 'erfcinv', 'erfcxf': 'erfcx',
'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10',
'exp2f': 'exp2', 'expm1f': 'expm1', 'fabsf': 'abs',
'fabs': 'abs', 'fast_fdividef': 'fast_dividef',
'fdimf': 'fdim', 'ffsll': 'ffs', 'floorf': 'floor',
'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn',
'fmaf_ru': 'fma_ru', 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod',
'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb',
'isinff': 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan',
'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint',
'llroundf': 'llround', 'logf': 'log', 'log10f': 'log10',
'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb',
'umax': 'max', 'llmax': 'max', 'ullmax': 'max', 'fmaxf': 'max',
'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min',
'fminf': 'min', 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd',
'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', 'dmul_ru': 'mul_ru',
'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz',
'umul24': 'mul24', 'umulhi': 'mulhi', 'mul64hi': 'mulhi',
'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': 'nextafter',
'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf',
'normcdfinvf': 'normcdfinv', 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow',
'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', 'drcp_rd': 'rcp_rd',
'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru',
'drcp_ru': 'rcp_ru', 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz',
'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d',
'roundf': 'round', 'rsqrtf': 'rsqrt', 'frsqrt_rn': 'rsqrt_rn',
'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit',
'signbitd': 'signbit', 'sinf': 'sin', 'sinhf': 'sinh',
'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd',
'dsqrt_rd': 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn',
'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', 'fsqrt_rz': 'sqrt_rz',
'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd',
'fsub_rn': 'sub_rn', 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru',
'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc',
'y0f': 'y0', 'y1f': 'y1', 'ynf': 'yn'
'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn':
'add_rn', 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', 'dadd_rz': 'add_rz', 'fadd_rz':
'add_rz', 'asinf': 'asin', 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', 'atanhf': 'atanh',
'brevll': 'brev', 'cbrtf': 'cbrt', 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', 'cosf': 'cos',
'coshf': 'cosh', 'cospif': 'cospi', 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru',
'ddiv_ru': 'div_ru', 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', 'erfcf': 'erfc', 'erfcinvf':
'erfcinv', 'erfcxf': 'erfcx', 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', 'exp2f': 'exp2',
'expm1f': 'expm1', 'fabsf': 'abs', 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', 'fdimf': 'fdim', 'ffsll':
'ffs', 'floorf': 'floor', 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', 'fmaf_ru': 'fma_ru',
'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', 'isinff':
'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', 'llroundf': 'llround', 'logf': 'log', 'log10f':
'log10', 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', 'umax': 'max', 'llmax': 'max', 'ullmax':
'max', 'fmaxf': 'max', 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', 'fminf': 'min',
'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn',
'dmul_ru': 'mul_ru', 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', 'umul24': 'mul24',
'umulhi': 'mulhi', 'mul64hi': 'mulhi', 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf':
'nextafter', 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', 'normcdfinvf': 'normcdfinv',
'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd',
'drcp_rd': 'rcp_rd', 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', 'drcp_ru': 'rcp_ru',
'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', 'roundf': 'round', 'rsqrtf': 'rsqrt',
'frsqrt_rn': 'rsqrt_rn', 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', 'signbitd': 'signbit',
'sinf': 'sin', 'sinhf': 'sinh', 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', 'dsqrt_rd':
'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru',
'fsqrt_rz': 'sqrt_rz', 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', 'fsub_rn': 'sub_rn',
'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', 'y0f': 'y0', 'y1f': 'y1', 'ynf':
'yn'
}
for symbol in self._symbols.values():
@@ -347,8 +326,7 @@ class LLVMDisassembler:
self._ll_file = "/tmp/extern_lib.ll"
def disasm(self, lib_path: str) -> None:
subprocess.Popen([self._path, lib_path, "-o", self.ll_file],
stdout=subprocess.PIPE).communicate()
subprocess.Popen([self._path, lib_path, "-o", self.ll_file], stdout=subprocess.PIPE).communicate()
@property
def ll_file(self) -> str:

View File

@@ -40,10 +40,13 @@ if __name__ == "__main__":
# command-line arguments
parser = ArgumentParser(description=desc)
parser.add_argument("path", help="Path to Python source containing desired kernel in its scope. File will be executed.")
parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", required=True)
parser.add_argument("path",
help="Path to Python source containing desired kernel in its scope. File will be executed.")
parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile",
required=True)
parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel")
parser.add_argument("--num-stages", "-ns", type=int, default=3, help="Number of stages (meta-parameter of the kernel)")
parser.add_argument("--num-stages", "-ns", type=int, default=3,
help="Number of stages (meta-parameter of the kernel)")
parser.add_argument("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel")
parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename")
parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True)
@@ -104,7 +107,8 @@ if __name__ == "__main__":
config = triton.compiler.instance_descriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1)
for i in equal_to_1:
constexprs.update({i: 1})
ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config], num_warps=args.num_warps, num_stages=args.num_stages)
ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config],
num_warps=args.num_warps, num_stages=args.num_stages)
arg_names = []
arg_types = []
for i in signature.keys():

View File

@@ -27,13 +27,12 @@ class KernelLinkerMeta:
class HeaderParser:
def __init__(self) -> None:
import re
# [kernel_name, c signature]
self.linker_directives = re.compile(
"//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)"
)
self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)")
# [name, hash, suffix]
self.kernel_name = re.compile("^([\\w]+)_([\\w]+)_([\\w]+)$")
# [(type, name)]
@@ -153,9 +152,7 @@ void unload_{meta.orig_kernel_name}();
# generate dispatcher function for kernels with different meta-parameter and constant values
def make_default_algo_kernel(meta: KernelLinkerMeta) -> str:
src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n"
src += (
f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n"
)
src += (f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n")
src += "}\n"
return src
@@ -167,28 +164,22 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -
src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n"
src += "\n"
src += (
f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{"
)
src += (f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{")
src += "\n"
for meta in sorted(metas, key=lambda m: -m.num_specs):
cond_fn = (
lambda val, hint: f"({val} % {hint} == 0)"
if hint == 16
else f"({val} == {hint})"
if hint == 1
else None
)
conds = " && ".join(
[
cond_fn(val, hint)
for val, hint in zip(meta.arg_names, meta.sizes)
if hint is not None
]
)
src += (
f" if ({conds})\n" if any(meta.sizes) else "if (1)\n"
) # Edge case where no specializations hence no dispatching required
cond_fn = ( #
lambda val, hint: f"({val} % {hint} == 0)" #
if hint == 16 #
else f"({val} == {hint})" #
if hint == 1 #
else None)
conds = " && ".join([ #
cond_fn(val, hint) #
for val, hint in zip(meta.arg_names, meta.sizes) #
if hint is not None
])
src += (f" if ({conds})\n" if any(meta.sizes) else "if (1)\n"
) # Edge case where no specializations hence no dispatching required
arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1]
src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n"
src += "\n"
@@ -202,9 +193,7 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -
src += f"void {mode}_{name}() {{"
src += "\n"
for meta in sorted(metas, key=lambda m: -m.num_specs):
src += (
f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n"
)
src += (f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n")
src += "}\n"
return src
@@ -306,10 +295,7 @@ if __name__ == "__main__":
fp.write(out)
# generate source
defs = [
make_kernel_hints_dispatcher(name, meta)
for name, meta in parser.kernels.items()
]
defs = [make_kernel_hints_dispatcher(name, meta) for name, meta in parser.kernels.items()]
names = [name for name in parser.kernels.keys()]
func_pointers_def = make_func_pointers(names, meta)
meta_const_def = make_kernel_meta_const_dispatcher(meta)