[FRONTEND] CompilationError._format_message issue + tidying (#1362)

- fixed `CompilationError._format_message` fails when `error_message` is
a `constexpr`
- factored out `_is_constexpr()` checks and `_unwrap_if_constexpr()`
idioms
- Added `UnsupportedLanguageConstruct` exception, replaced some python
builtin exceptions raised in such cases.
- Some hardening in `.visit_If()`
- cleaner exception handling in `build_triton_ir()`
This commit is contained in:
mcskatkat
2023-03-21 21:52:18 +02:00
committed by GitHub
parent c1dd6df9ce
commit 9ae78d21f1

View File

@@ -16,7 +16,7 @@ import tempfile
import warnings
from collections import namedtuple
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
import setuptools
import torch
@@ -89,6 +89,21 @@ def mangle_fn(name, arg_tys, constants):
return ret
def _is_triton_tensor(o: Any) -> bool:
return isinstance(o, triton.language.tensor)
def _is_constexpr(o: Any) -> bool:
return isinstance(o, triton.language.constexpr) # TODO: fetch triton.language.constexpr to a global after circular imports untangled, saving getattr
def _unwrap_if_constexpr(o: Any):
return o.value if isinstance(o, triton.language.constexpr) else o
_condition_types = {bool, int} # Python types accepted for conditionals inside kernels
class enter_sub_region:
def __init__(self, generator: CodeGenerator):
self.generator = generator
@@ -173,9 +188,6 @@ class CodeGenerator(ast.NodeVisitor):
self.lscope[name] = value
self.local_defs[name] = value
def is_triton_tensor(self, value):
return isinstance(value, triton.language.tensor)
#
# AST visitor
#
@@ -271,7 +283,7 @@ class CodeGenerator(ast.NodeVisitor):
for i, arg_name in enumerate(arg_names):
if i in self.constants:
cst = self.constants[i]
if not isinstance(cst, triton.language.constexpr):
if not _is_constexpr(cst):
cst = triton.language.constexpr(self.constants[i])
arg_values.append(cst)
continue
@@ -322,7 +334,7 @@ class CodeGenerator(ast.NodeVisitor):
if target in self.lscope:
raise ValueError(f'{target} is already defined.'
f' constexpr cannot be reassigned.')
if not isinstance(value, triton.language.constexpr):
if not _is_constexpr(value):
value = triton.language.constexpr(value)
self.lscope[target] = value
return self.lscope[target]
@@ -334,7 +346,7 @@ class CodeGenerator(ast.NodeVisitor):
for target in node.targets:
_names += [self.visit(target)]
if len(_names) > 1:
raise NotImplementedError("Multiple assignment is not supported.")
raise UnsupportedLanguageConstruct(None, node, "simultaneous multiple assignment is not supported.")
names = _names[0]
values = self.visit(node.value)
if not isinstance(names, tuple):
@@ -344,9 +356,8 @@ class CodeGenerator(ast.NodeVisitor):
native_nontensor_types = (triton.language.dtype, )
for name, value in zip(names, values):
# by default, constexpr are assigned into python variable
if isinstance(value, triton.language.constexpr):
value = value.value
if not isinstance(value, triton.language.tensor) and \
value = _unwrap_if_constexpr(value)
if not _is_triton_tensor(value) and \
not isinstance(value, native_nontensor_types):
value = triton.language.core._to_tensor(value, self.builder)
self.set_value(name, value)
@@ -374,30 +385,27 @@ class CodeGenerator(ast.NodeVisitor):
args = [self.visit(x) for x in node.elts]
return tuple(args)
def _apply_binary_method(self, method_name, lhs, rhs):
# TODO: raise something meaningful if getattr fails below, esp for reverse method
if _is_triton_tensor(lhs):
return getattr(lhs, method_name)(rhs, _builder=self.builder)
if _is_triton_tensor(rhs):
reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name)
return getattr(rhs, reverse_method_name)(lhs, _builder=self.builder)
return getattr(lhs, method_name)(rhs)
def visit_BinOp(self, node):
lhs = self.visit(node.left)
rhs = self.visit(node.right)
fn = {
ast.Add: '__add__',
ast.Sub: '__sub__',
ast.Mult: '__mul__',
ast.Div: '__truediv__',
ast.FloorDiv: '__floordiv__',
ast.Mod: '__mod__',
ast.Pow: '__pow__',
ast.LShift: '__lshift__',
ast.RShift: '__rshift__',
ast.BitAnd: '__and__',
ast.BitOr: '__or__',
ast.BitXor: '__xor__',
}[type(node.op)]
if self.is_triton_tensor(lhs):
return getattr(lhs, fn)(rhs, _builder=self.builder)
elif self.is_triton_tensor(rhs):
fn = fn[:2] + 'r' + fn[2:]
return getattr(rhs, fn)(lhs, _builder=self.builder)
else:
return getattr(lhs, fn)(rhs)
method_name = self._method_name_for_bin_op.get(type(node.op))
if method_name is None:
raise UnsupportedLanguageConstruct(None, node, "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__))
return self._apply_binary_method(method_name, lhs, rhs)
_method_name_for_bin_op: Dict[Type[ast.operator], str] = {
ast.Add: '__add__', ast.Sub: '__sub__', ast.Mult: '__mul__', ast.Div: '__truediv__',
ast.FloorDiv: '__floordiv__', ast.Mod: '__mod__', ast.Pow: '__pow__',
ast.LShift: '__lshift__', ast.RShift: '__rshift__', ast.BitAnd: '__and__', ast.BitOr: '__or__', ast.BitXor: '__xor__',
}
def visit_then_else_blocks(self, node, liveins, then_block, else_block):
# then block
@@ -513,15 +521,18 @@ class CodeGenerator(ast.NodeVisitor):
def visit_If(self, node):
cond = self.visit(node.test)
if isinstance(cond, triton.language.tensor):
if _is_triton_tensor(cond):
cond = cond.to(triton.language.int1, _builder=self.builder)
if self.scf_stack or not self.contains_return_op(node):
self.visit_if_scf(cond, node)
else:
self.visit_if_top_level(cond, node)
else:
if isinstance(cond, triton.language.constexpr):
cond = cond.value
cond = _unwrap_if_constexpr(cond)
if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks
raise UnsupportedLanguageConstruct(
None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
', '.join(_.__name__ for _ in _condition_types), type(cond).__name__))
if cond:
self.visit_compound_statement(node.body)
else:
@@ -538,45 +549,31 @@ class CodeGenerator(ast.NodeVisitor):
pass
def visit_Compare(self, node):
assert len(node.comparators) == 1
assert len(node.ops) == 1
lhs = self.visit(node.left)
rhs = self.visit(node.comparators[0])
if isinstance(lhs, triton.language.constexpr):
lhs = lhs.value
if isinstance(rhs, triton.language.constexpr):
rhs = rhs.value
if not (len(node.comparators) == 1 and len(node.ops) == 1):
raise UnsupportedLanguageConstruct(None, node, "simultaneous multiple comparison is not supported")
lhs = _unwrap_if_constexpr(self.visit(node.left))
rhs = _unwrap_if_constexpr(self.visit(node.comparators[0]))
if type(node.ops[0]) == ast.Is:
return triton.language.constexpr(lhs is rhs)
if type(node.ops[0]) == ast.IsNot:
return triton.language.constexpr(lhs is not rhs)
fn = {
ast.Eq: '__eq__',
ast.NotEq: '__ne__',
ast.Lt: '__lt__',
ast.LtE: '__le__',
ast.Gt: '__gt__',
ast.GtE: '__ge__',
}[type(node.ops[0])]
if self.is_triton_tensor(lhs):
return getattr(lhs, fn)(rhs, _builder=self.builder)
elif self.is_triton_tensor(rhs):
fn = fn[:2] + 'r' + fn[2:]
return getattr(rhs, fn)(lhs, _builder=self.builder)
else:
return getattr(lhs, fn)(rhs)
method_name = self._method_name_for_comp_op.get(type(node.ops[0]))
if method_name is None:
raise UnsupportedLanguageConstruct(None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
return self._apply_binary_method(method_name, lhs, rhs)
_method_name_for_comp_op: Dict[Type[ast.cmpop], str] = {
ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__'
}
def visit_UnaryOp(self, node):
op = self.visit(node.operand)
fn = {
ast.USub: '__neg__',
ast.UAdd: '__pos__',
ast.Not: '__not__',
ast.Invert: '__invert__',
}[type(node.op)]
if self.is_triton_tensor(op):
fn = self._method_name_for_unary_op.get(type(node.op))
if fn is None:
raise UnsupportedLanguageConstruct(None, node, "AST unary operator '{}' is not (currently) implemented.".format(node.op.__name__))
if _is_triton_tensor(op):
return getattr(op, fn)(_builder=self.builder)
return getattr(op, fn)()
_method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'}
def visit_While(self, node):
with enter_sub_region(self) as sr:
@@ -598,8 +595,8 @@ class CodeGenerator(ast.NodeVisitor):
for name in loop_defs:
if name in liveins:
# We should not def new constexpr
assert self.is_triton_tensor(loop_defs[name])
assert self.is_triton_tensor(liveins[name])
assert _is_triton_tensor(loop_defs[name])
assert _is_triton_tensor(liveins[name])
assert loop_defs[name].type == liveins[name].type
# these are loop-carried values
names.append(name)
@@ -657,7 +654,7 @@ class CodeGenerator(ast.NodeVisitor):
assert node.ctx.__class__.__name__ == "Load"
lhs = self.visit(node.value)
slices = self.visit(node.slice)
if self.is_triton_tensor(lhs):
if _is_triton_tensor(lhs):
return lhs.__getitem__(slices, _builder=self.builder)
return lhs[slices]
@@ -690,7 +687,7 @@ class CodeGenerator(ast.NodeVisitor):
step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1))
# handle negative constant step (not supported by scf.for in MLIR)
negative_step = False
if isinstance(step, triton.language.constexpr) and step.value < 0:
if _is_constexpr(step) and step.value < 0:
step = triton.language.constexpr(-step.value)
negative_step = True
lb, ub = ub, lb
@@ -734,8 +731,8 @@ class CodeGenerator(ast.NodeVisitor):
names = []
for name in self.local_defs:
if name in liveins:
assert self.is_triton_tensor(self.local_defs[name]), f'{name} is not tensor'
assert self.is_triton_tensor(liveins[name])
assert _is_triton_tensor(self.local_defs[name]), f'{name} is not tensor'
assert _is_triton_tensor(liveins[name])
assert self.local_defs[name].type == liveins[name].type,\
f'Loop-carried variable {name} has initial type {liveins[name].type} '\
f'but is re-assigned to {self.local_defs[name].type} in loop! '\
@@ -804,9 +801,7 @@ class CodeGenerator(ast.NodeVisitor):
return triton.language.core.device_assert(test, msg, _builder=self.builder)
def visit_Call(self, node):
fn = self.visit(node.func)
if isinstance(fn, triton.language.constexpr):
fn = fn.value
fn = _unwrap_if_constexpr(self.visit(node.func))
static_implementation = self.statically_implemented_functions.get(fn)
if static_implementation is not None:
@@ -821,11 +816,11 @@ class CodeGenerator(ast.NodeVisitor):
from inspect import getcallargs
args = getcallargs(fn.fn, *args, **kws)
args = [args[name] for name in fn.arg_names]
args = [arg if isinstance(arg, triton.language.tensor)
args = [arg if _is_triton_tensor(arg)
else triton.language.constexpr(arg) for arg in args]
# generate function def
attributes = dict()
constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)]
constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
constants = {i: args[i] for i in constexprs}
# generate call
args = [None if i in constexprs else arg for i, arg in enumerate(args)]
@@ -854,32 +849,25 @@ class CodeGenerator(ast.NodeVisitor):
for i in range(call_op.get_num_results()):
results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i]))
return tuple(results)
if (hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__)) or impl.is_builtin(fn):
if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or impl.is_builtin(fn):
return fn(*args, _builder=self.builder, **kws)
if fn in self.builtin_namespace.values():
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg for arg in args]
args = map(_unwrap_if_constexpr, args)
return fn(*args, **kws)
def visit_Constant(self, node):
return triton.language.constexpr(node.value)
def visit_BoolOp(self, node: ast.BoolOp):
assert len(node.values) == 2
if len(node.values) != 2:
raise UnsupportedLanguageConstruct(None, node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
lhs = self.visit(node.values[0])
rhs = self.visit(node.values[1])
fn = {
ast.And: 'logical_and',
ast.Or: 'logical_or',
}[type(node.op)]
if self.is_triton_tensor(lhs):
return getattr(lhs, fn)(rhs, _builder=self.builder)
elif self.is_triton_tensor(rhs):
fn = fn[:2] + 'r' + fn[2:]
return getattr(rhs, fn)(lhs, _builder=self.builder)
else:
return getattr(lhs, fn)(rhs)
method_name = self._method_name_for_bool_op.get(type(node.op))
if method_name is None:
raise UnsupportedLanguageConstruct(None, node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
return self._apply_binary_method(method_name, lhs, rhs)
_method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}
if sys.version_info < (3, 8):
def visit_NameConstant(self, node):
@@ -893,7 +881,7 @@ class CodeGenerator(ast.NodeVisitor):
def visit_Attribute(self, node):
lhs = self.visit(node.value)
if isinstance(lhs, triton.language.tensor):
if _is_triton_tensor(lhs):
if node.attr == "T":
return triton.language.semantic.trans(lhs, builder=self.builder)
return getattr(lhs, node.attr)
@@ -912,9 +900,9 @@ class CodeGenerator(ast.NodeVisitor):
elif isinstance(value, ast.FormattedValue):
conversion_code = value.conversion
evaluated = self.visit(value.value)
if not isinstance(evaluated, triton.language.constexpr):
raise NotImplementedError("Cannot evaluate f-string containing non-constexpr conversion values,"
" found conversion of type " + str(type(evaluated)))
if not _is_constexpr(evaluated):
raise UnsupportedLanguageConstruct(
None, node, "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + str(type(evaluated)))
values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value)
else:
raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value)))
@@ -931,19 +919,16 @@ class CodeGenerator(ast.NodeVisitor):
return super().visit(node)
def generic_visit(self, node):
typename = type(node).__name__
raise NotImplementedError("Unsupported node: {}".format(typename))
raise UnsupportedLanguageConstruct(None, node, "unsupported AST node type: {}".format(type(node).__name__))
# TODO: populate this here (rather than inside `_define_name_lookup`) once cyclic imports resolved
statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = {}
def execute_static_print(self, node: ast.Call) -> None:
# TODO: too simplistic? Perhaps do something else with non-constexpr
def unwrap(_):
return _.value if isinstance(_, triton.language.constexpr) else _
kws = {name: unwrap(value) for name, value in (self.visit(keyword) for keyword in node.keywords)}
args = [unwrap(self.visit(arg)) for arg in node.args]
kws = {name: _unwrap_if_constexpr(value) for name, value in (self.visit(keyword) for keyword in node.keywords)}
args = [_unwrap_if_constexpr(self.visit(arg)) for arg in node.args]
print(*args, **kws)
def execute_static_assert(self, node: ast.Call) -> None:
@@ -972,20 +957,29 @@ class CompilationError(Exception):
def _format_message(self) -> str:
node = self.node
message = f'at {node.lineno}:{node.col_offset}:'
if self.src is None:
message += " <source unavailable>"
source_excerpt = " <source unavailable>"
else:
message += '\n'.join(self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:])
message += '\n' + ' ' * node.col_offset + '^'
source_excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:]
if source_excerpt:
source_excerpt.append(' ' * node.col_offset + '^')
source_excerpt = '\n'.join(source_excerpt)
else:
source_excerpt = " <source empty>"
message = "at {}:{}:{}".format(node.lineno, node.col_offset, source_excerpt)
if self.error_message:
message += '\n' + self.error_message
return message
def __init__(self, src: Optional[str], node: ast.AST, error_message: Optional[str]):
def __init__(self, src: Optional[str], node: ast.AST, error_message: Union[str, 'triton.language.core.constexpr', None]):
self.src = src
self.node = node
self.error_message = error_message
self.error_message = _unwrap_if_constexpr(error_message)
self.message = self._format_message()
def set_source_code(self, src: Optional[str]):
self.src = src
self.message = self._format_message()
def __str__(self):
@@ -1001,10 +995,11 @@ class CompilationError(Exception):
class CompileTimeAssertionFailure(CompilationError):
"""Specific exception for failed tests in `static_assert` invocations"""
pass
def set_source_code(self, src: Optional[str]):
self.src = src
self.message = self._format_message()
class UnsupportedLanguageConstruct(CompilationError):
pass
class OutOfResources(Exception):
@@ -1069,11 +1064,10 @@ def build_triton_ir(fn, signature, specialization, constants, debug=False):
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, attributes=new_attrs, is_kernel=True, debug=debug)
try:
generator.visit(fn.parse())
except CompileTimeAssertionFailure as e:
e.set_source_code(fn.src)
except CompilationError as e:
if e.src is None:
e.set_source_code(fn.src)
raise
except CompilationError: # (can this ever happen? nobody has access to fn.src except here)
raise # unchanged
except Exception as e:
node = generator.last_node
if node is None: