mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] CompilationError._format_message issue + tidying (#1362)
- fixed `CompilationError._format_message` fails when `error_message` is a `constexpr` - factored out `_is_constexpr()` checks and `_unwrap_if_constexpr()` idioms - Added `UnsupportedLanguageConstruct` exception, replaced some python builtin exceptions raised in such cases. - Some hardening in `.visit_If()` - cleaner exception handling in `build_triton_ir()`
This commit is contained in:
@@ -16,7 +16,7 @@ import tempfile
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
import setuptools
|
||||
import torch
|
||||
@@ -89,6 +89,21 @@ def mangle_fn(name, arg_tys, constants):
|
||||
return ret
|
||||
|
||||
|
||||
def _is_triton_tensor(o: Any) -> bool:
|
||||
return isinstance(o, triton.language.tensor)
|
||||
|
||||
|
||||
def _is_constexpr(o: Any) -> bool:
|
||||
return isinstance(o, triton.language.constexpr) # TODO: fetch triton.language.constexpr to a global after circular imports untangled, saving getattr
|
||||
|
||||
|
||||
def _unwrap_if_constexpr(o: Any):
|
||||
return o.value if isinstance(o, triton.language.constexpr) else o
|
||||
|
||||
|
||||
_condition_types = {bool, int} # Python types accepted for conditionals inside kernels
|
||||
|
||||
|
||||
class enter_sub_region:
|
||||
def __init__(self, generator: CodeGenerator):
|
||||
self.generator = generator
|
||||
@@ -173,9 +188,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.lscope[name] = value
|
||||
self.local_defs[name] = value
|
||||
|
||||
def is_triton_tensor(self, value):
|
||||
return isinstance(value, triton.language.tensor)
|
||||
|
||||
#
|
||||
# AST visitor
|
||||
#
|
||||
@@ -271,7 +283,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
for i, arg_name in enumerate(arg_names):
|
||||
if i in self.constants:
|
||||
cst = self.constants[i]
|
||||
if not isinstance(cst, triton.language.constexpr):
|
||||
if not _is_constexpr(cst):
|
||||
cst = triton.language.constexpr(self.constants[i])
|
||||
arg_values.append(cst)
|
||||
continue
|
||||
@@ -322,7 +334,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if target in self.lscope:
|
||||
raise ValueError(f'{target} is already defined.'
|
||||
f' constexpr cannot be reassigned.')
|
||||
if not isinstance(value, triton.language.constexpr):
|
||||
if not _is_constexpr(value):
|
||||
value = triton.language.constexpr(value)
|
||||
self.lscope[target] = value
|
||||
return self.lscope[target]
|
||||
@@ -334,7 +346,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
for target in node.targets:
|
||||
_names += [self.visit(target)]
|
||||
if len(_names) > 1:
|
||||
raise NotImplementedError("Multiple assignment is not supported.")
|
||||
raise UnsupportedLanguageConstruct(None, node, "simultaneous multiple assignment is not supported.")
|
||||
names = _names[0]
|
||||
values = self.visit(node.value)
|
||||
if not isinstance(names, tuple):
|
||||
@@ -344,9 +356,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
native_nontensor_types = (triton.language.dtype, )
|
||||
for name, value in zip(names, values):
|
||||
# by default, constexpr are assigned into python variable
|
||||
if isinstance(value, triton.language.constexpr):
|
||||
value = value.value
|
||||
if not isinstance(value, triton.language.tensor) and \
|
||||
value = _unwrap_if_constexpr(value)
|
||||
if not _is_triton_tensor(value) and \
|
||||
not isinstance(value, native_nontensor_types):
|
||||
value = triton.language.core._to_tensor(value, self.builder)
|
||||
self.set_value(name, value)
|
||||
@@ -374,30 +385,27 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
args = [self.visit(x) for x in node.elts]
|
||||
return tuple(args)
|
||||
|
||||
def _apply_binary_method(self, method_name, lhs, rhs):
|
||||
# TODO: raise something meaningful if getattr fails below, esp for reverse method
|
||||
if _is_triton_tensor(lhs):
|
||||
return getattr(lhs, method_name)(rhs, _builder=self.builder)
|
||||
if _is_triton_tensor(rhs):
|
||||
reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name)
|
||||
return getattr(rhs, reverse_method_name)(lhs, _builder=self.builder)
|
||||
return getattr(lhs, method_name)(rhs)
|
||||
|
||||
def visit_BinOp(self, node):
|
||||
lhs = self.visit(node.left)
|
||||
rhs = self.visit(node.right)
|
||||
fn = {
|
||||
ast.Add: '__add__',
|
||||
ast.Sub: '__sub__',
|
||||
ast.Mult: '__mul__',
|
||||
ast.Div: '__truediv__',
|
||||
ast.FloorDiv: '__floordiv__',
|
||||
ast.Mod: '__mod__',
|
||||
ast.Pow: '__pow__',
|
||||
ast.LShift: '__lshift__',
|
||||
ast.RShift: '__rshift__',
|
||||
ast.BitAnd: '__and__',
|
||||
ast.BitOr: '__or__',
|
||||
ast.BitXor: '__xor__',
|
||||
}[type(node.op)]
|
||||
if self.is_triton_tensor(lhs):
|
||||
return getattr(lhs, fn)(rhs, _builder=self.builder)
|
||||
elif self.is_triton_tensor(rhs):
|
||||
fn = fn[:2] + 'r' + fn[2:]
|
||||
return getattr(rhs, fn)(lhs, _builder=self.builder)
|
||||
else:
|
||||
return getattr(lhs, fn)(rhs)
|
||||
method_name = self._method_name_for_bin_op.get(type(node.op))
|
||||
if method_name is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
return self._apply_binary_method(method_name, lhs, rhs)
|
||||
_method_name_for_bin_op: Dict[Type[ast.operator], str] = {
|
||||
ast.Add: '__add__', ast.Sub: '__sub__', ast.Mult: '__mul__', ast.Div: '__truediv__',
|
||||
ast.FloorDiv: '__floordiv__', ast.Mod: '__mod__', ast.Pow: '__pow__',
|
||||
ast.LShift: '__lshift__', ast.RShift: '__rshift__', ast.BitAnd: '__and__', ast.BitOr: '__or__', ast.BitXor: '__xor__',
|
||||
}
|
||||
|
||||
def visit_then_else_blocks(self, node, liveins, then_block, else_block):
|
||||
# then block
|
||||
@@ -513,15 +521,18 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def visit_If(self, node):
|
||||
cond = self.visit(node.test)
|
||||
if isinstance(cond, triton.language.tensor):
|
||||
if _is_triton_tensor(cond):
|
||||
cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||
if self.scf_stack or not self.contains_return_op(node):
|
||||
self.visit_if_scf(cond, node)
|
||||
else:
|
||||
self.visit_if_top_level(cond, node)
|
||||
else:
|
||||
if isinstance(cond, triton.language.constexpr):
|
||||
cond = cond.value
|
||||
cond = _unwrap_if_constexpr(cond)
|
||||
if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
|
||||
', '.join(_.__name__ for _ in _condition_types), type(cond).__name__))
|
||||
if cond:
|
||||
self.visit_compound_statement(node.body)
|
||||
else:
|
||||
@@ -538,45 +549,31 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
pass
|
||||
|
||||
def visit_Compare(self, node):
|
||||
assert len(node.comparators) == 1
|
||||
assert len(node.ops) == 1
|
||||
lhs = self.visit(node.left)
|
||||
rhs = self.visit(node.comparators[0])
|
||||
if isinstance(lhs, triton.language.constexpr):
|
||||
lhs = lhs.value
|
||||
if isinstance(rhs, triton.language.constexpr):
|
||||
rhs = rhs.value
|
||||
if not (len(node.comparators) == 1 and len(node.ops) == 1):
|
||||
raise UnsupportedLanguageConstruct(None, node, "simultaneous multiple comparison is not supported")
|
||||
lhs = _unwrap_if_constexpr(self.visit(node.left))
|
||||
rhs = _unwrap_if_constexpr(self.visit(node.comparators[0]))
|
||||
if type(node.ops[0]) == ast.Is:
|
||||
return triton.language.constexpr(lhs is rhs)
|
||||
if type(node.ops[0]) == ast.IsNot:
|
||||
return triton.language.constexpr(lhs is not rhs)
|
||||
fn = {
|
||||
ast.Eq: '__eq__',
|
||||
ast.NotEq: '__ne__',
|
||||
ast.Lt: '__lt__',
|
||||
ast.LtE: '__le__',
|
||||
ast.Gt: '__gt__',
|
||||
ast.GtE: '__ge__',
|
||||
}[type(node.ops[0])]
|
||||
if self.is_triton_tensor(lhs):
|
||||
return getattr(lhs, fn)(rhs, _builder=self.builder)
|
||||
elif self.is_triton_tensor(rhs):
|
||||
fn = fn[:2] + 'r' + fn[2:]
|
||||
return getattr(rhs, fn)(lhs, _builder=self.builder)
|
||||
else:
|
||||
return getattr(lhs, fn)(rhs)
|
||||
method_name = self._method_name_for_comp_op.get(type(node.ops[0]))
|
||||
if method_name is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
|
||||
return self._apply_binary_method(method_name, lhs, rhs)
|
||||
_method_name_for_comp_op: Dict[Type[ast.cmpop], str] = {
|
||||
ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__'
|
||||
}
|
||||
|
||||
def visit_UnaryOp(self, node):
|
||||
op = self.visit(node.operand)
|
||||
fn = {
|
||||
ast.USub: '__neg__',
|
||||
ast.UAdd: '__pos__',
|
||||
ast.Not: '__not__',
|
||||
ast.Invert: '__invert__',
|
||||
}[type(node.op)]
|
||||
if self.is_triton_tensor(op):
|
||||
fn = self._method_name_for_unary_op.get(type(node.op))
|
||||
if fn is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST unary operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
if _is_triton_tensor(op):
|
||||
return getattr(op, fn)(_builder=self.builder)
|
||||
return getattr(op, fn)()
|
||||
_method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'}
|
||||
|
||||
def visit_While(self, node):
|
||||
with enter_sub_region(self) as sr:
|
||||
@@ -598,8 +595,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
for name in loop_defs:
|
||||
if name in liveins:
|
||||
# We should not def new constexpr
|
||||
assert self.is_triton_tensor(loop_defs[name])
|
||||
assert self.is_triton_tensor(liveins[name])
|
||||
assert _is_triton_tensor(loop_defs[name])
|
||||
assert _is_triton_tensor(liveins[name])
|
||||
assert loop_defs[name].type == liveins[name].type
|
||||
# these are loop-carried values
|
||||
names.append(name)
|
||||
@@ -657,7 +654,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
assert node.ctx.__class__.__name__ == "Load"
|
||||
lhs = self.visit(node.value)
|
||||
slices = self.visit(node.slice)
|
||||
if self.is_triton_tensor(lhs):
|
||||
if _is_triton_tensor(lhs):
|
||||
return lhs.__getitem__(slices, _builder=self.builder)
|
||||
return lhs[slices]
|
||||
|
||||
@@ -690,7 +687,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1))
|
||||
# handle negative constant step (not supported by scf.for in MLIR)
|
||||
negative_step = False
|
||||
if isinstance(step, triton.language.constexpr) and step.value < 0:
|
||||
if _is_constexpr(step) and step.value < 0:
|
||||
step = triton.language.constexpr(-step.value)
|
||||
negative_step = True
|
||||
lb, ub = ub, lb
|
||||
@@ -734,8 +731,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
names = []
|
||||
for name in self.local_defs:
|
||||
if name in liveins:
|
||||
assert self.is_triton_tensor(self.local_defs[name]), f'{name} is not tensor'
|
||||
assert self.is_triton_tensor(liveins[name])
|
||||
assert _is_triton_tensor(self.local_defs[name]), f'{name} is not tensor'
|
||||
assert _is_triton_tensor(liveins[name])
|
||||
assert self.local_defs[name].type == liveins[name].type,\
|
||||
f'Loop-carried variable {name} has initial type {liveins[name].type} '\
|
||||
f'but is re-assigned to {self.local_defs[name].type} in loop! '\
|
||||
@@ -804,9 +801,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
return triton.language.core.device_assert(test, msg, _builder=self.builder)
|
||||
|
||||
def visit_Call(self, node):
|
||||
fn = self.visit(node.func)
|
||||
if isinstance(fn, triton.language.constexpr):
|
||||
fn = fn.value
|
||||
fn = _unwrap_if_constexpr(self.visit(node.func))
|
||||
|
||||
static_implementation = self.statically_implemented_functions.get(fn)
|
||||
if static_implementation is not None:
|
||||
@@ -821,11 +816,11 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
from inspect import getcallargs
|
||||
args = getcallargs(fn.fn, *args, **kws)
|
||||
args = [args[name] for name in fn.arg_names]
|
||||
args = [arg if isinstance(arg, triton.language.tensor)
|
||||
args = [arg if _is_triton_tensor(arg)
|
||||
else triton.language.constexpr(arg) for arg in args]
|
||||
# generate function def
|
||||
attributes = dict()
|
||||
constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)]
|
||||
constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
|
||||
constants = {i: args[i] for i in constexprs}
|
||||
# generate call
|
||||
args = [None if i in constexprs else arg for i, arg in enumerate(args)]
|
||||
@@ -854,32 +849,25 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
for i in range(call_op.get_num_results()):
|
||||
results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i]))
|
||||
return tuple(results)
|
||||
if (hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__)) or impl.is_builtin(fn):
|
||||
if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or impl.is_builtin(fn):
|
||||
return fn(*args, _builder=self.builder, **kws)
|
||||
if fn in self.builtin_namespace.values():
|
||||
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg for arg in args]
|
||||
args = map(_unwrap_if_constexpr, args)
|
||||
return fn(*args, **kws)
|
||||
|
||||
def visit_Constant(self, node):
|
||||
return triton.language.constexpr(node.value)
|
||||
|
||||
def visit_BoolOp(self, node: ast.BoolOp):
|
||||
assert len(node.values) == 2
|
||||
if len(node.values) != 2:
|
||||
raise UnsupportedLanguageConstruct(None, node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
|
||||
lhs = self.visit(node.values[0])
|
||||
rhs = self.visit(node.values[1])
|
||||
|
||||
fn = {
|
||||
ast.And: 'logical_and',
|
||||
ast.Or: 'logical_or',
|
||||
}[type(node.op)]
|
||||
|
||||
if self.is_triton_tensor(lhs):
|
||||
return getattr(lhs, fn)(rhs, _builder=self.builder)
|
||||
elif self.is_triton_tensor(rhs):
|
||||
fn = fn[:2] + 'r' + fn[2:]
|
||||
return getattr(rhs, fn)(lhs, _builder=self.builder)
|
||||
else:
|
||||
return getattr(lhs, fn)(rhs)
|
||||
method_name = self._method_name_for_bool_op.get(type(node.op))
|
||||
if method_name is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
||||
return self._apply_binary_method(method_name, lhs, rhs)
|
||||
_method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
def visit_NameConstant(self, node):
|
||||
@@ -893,7 +881,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
lhs = self.visit(node.value)
|
||||
if isinstance(lhs, triton.language.tensor):
|
||||
if _is_triton_tensor(lhs):
|
||||
if node.attr == "T":
|
||||
return triton.language.semantic.trans(lhs, builder=self.builder)
|
||||
return getattr(lhs, node.attr)
|
||||
@@ -912,9 +900,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
elif isinstance(value, ast.FormattedValue):
|
||||
conversion_code = value.conversion
|
||||
evaluated = self.visit(value.value)
|
||||
if not isinstance(evaluated, triton.language.constexpr):
|
||||
raise NotImplementedError("Cannot evaluate f-string containing non-constexpr conversion values,"
|
||||
" found conversion of type " + str(type(evaluated)))
|
||||
if not _is_constexpr(evaluated):
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + str(type(evaluated)))
|
||||
values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value)
|
||||
else:
|
||||
raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value)))
|
||||
@@ -931,19 +919,16 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
return super().visit(node)
|
||||
|
||||
def generic_visit(self, node):
|
||||
typename = type(node).__name__
|
||||
raise NotImplementedError("Unsupported node: {}".format(typename))
|
||||
raise UnsupportedLanguageConstruct(None, node, "unsupported AST node type: {}".format(type(node).__name__))
|
||||
|
||||
# TODO: populate this here (rather than inside `_define_name_lookup`) once cyclic imports resolved
|
||||
statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = {}
|
||||
|
||||
def execute_static_print(self, node: ast.Call) -> None:
|
||||
# TODO: too simplistic? Perhaps do something else with non-constexpr
|
||||
def unwrap(_):
|
||||
return _.value if isinstance(_, triton.language.constexpr) else _
|
||||
|
||||
kws = {name: unwrap(value) for name, value in (self.visit(keyword) for keyword in node.keywords)}
|
||||
args = [unwrap(self.visit(arg)) for arg in node.args]
|
||||
kws = {name: _unwrap_if_constexpr(value) for name, value in (self.visit(keyword) for keyword in node.keywords)}
|
||||
args = [_unwrap_if_constexpr(self.visit(arg)) for arg in node.args]
|
||||
print(*args, **kws)
|
||||
|
||||
def execute_static_assert(self, node: ast.Call) -> None:
|
||||
@@ -972,20 +957,29 @@ class CompilationError(Exception):
|
||||
|
||||
def _format_message(self) -> str:
|
||||
node = self.node
|
||||
message = f'at {node.lineno}:{node.col_offset}:'
|
||||
if self.src is None:
|
||||
message += " <source unavailable>"
|
||||
source_excerpt = " <source unavailable>"
|
||||
else:
|
||||
message += '\n'.join(self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:])
|
||||
message += '\n' + ' ' * node.col_offset + '^'
|
||||
source_excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:]
|
||||
if source_excerpt:
|
||||
source_excerpt.append(' ' * node.col_offset + '^')
|
||||
source_excerpt = '\n'.join(source_excerpt)
|
||||
else:
|
||||
source_excerpt = " <source empty>"
|
||||
|
||||
message = "at {}:{}:{}".format(node.lineno, node.col_offset, source_excerpt)
|
||||
if self.error_message:
|
||||
message += '\n' + self.error_message
|
||||
return message
|
||||
|
||||
def __init__(self, src: Optional[str], node: ast.AST, error_message: Optional[str]):
|
||||
def __init__(self, src: Optional[str], node: ast.AST, error_message: Union[str, 'triton.language.core.constexpr', None]):
|
||||
self.src = src
|
||||
self.node = node
|
||||
self.error_message = error_message
|
||||
self.error_message = _unwrap_if_constexpr(error_message)
|
||||
self.message = self._format_message()
|
||||
|
||||
def set_source_code(self, src: Optional[str]):
|
||||
self.src = src
|
||||
self.message = self._format_message()
|
||||
|
||||
def __str__(self):
|
||||
@@ -1001,10 +995,11 @@ class CompilationError(Exception):
|
||||
|
||||
class CompileTimeAssertionFailure(CompilationError):
|
||||
"""Specific exception for failed tests in `static_assert` invocations"""
|
||||
pass
|
||||
|
||||
def set_source_code(self, src: Optional[str]):
|
||||
self.src = src
|
||||
self.message = self._format_message()
|
||||
|
||||
class UnsupportedLanguageConstruct(CompilationError):
|
||||
pass
|
||||
|
||||
|
||||
class OutOfResources(Exception):
|
||||
@@ -1069,11 +1064,10 @@ def build_triton_ir(fn, signature, specialization, constants, debug=False):
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, attributes=new_attrs, is_kernel=True, debug=debug)
|
||||
try:
|
||||
generator.visit(fn.parse())
|
||||
except CompileTimeAssertionFailure as e:
|
||||
e.set_source_code(fn.src)
|
||||
except CompilationError as e:
|
||||
if e.src is None:
|
||||
e.set_source_code(fn.src)
|
||||
raise
|
||||
except CompilationError: # (can this ever happen? nobody has access to fn.src except here)
|
||||
raise # unchanged
|
||||
except Exception as e:
|
||||
node = generator.last_node
|
||||
if node is None:
|
||||
|
||||
Reference in New Issue
Block a user