[FRONTEND] CodeGenerator: enhanced (#1355)

Contents of this change to `CodeGenerator`:
- addressed mutable default value in constructor (GitHub #1353)
- structured and faster name lookup (replaces `.get_value`)
- added informative error messages in some places
- tidy mechanism for "static" (compile time) functions replaces inline
`if ... elif ...` chain in `.visit_Call`
- more robust `static_assert` and `static_print`
- more informative `CompilationError` display (saves scrolling up
through long tracebacks)
- dedicated `CompileTimeAssertionFailure` exception for `static_assert`,
which can be specially treated upstream by `Autotuner` to skip configurations
that violate constraints (as for `OutOfResources`)

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
mcskatkat
2023-03-17 02:00:43 +02:00
committed by GitHub
parent ba91f39dbf
commit 611a2dc9bf

View File

@@ -16,7 +16,7 @@ import tempfile
import warnings
from collections import namedtuple
from pathlib import Path
from typing import Any, Callable, Dict, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union
import setuptools
import torch
@@ -109,10 +109,11 @@ class enter_sub_region:
class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, attributes, constants, function_name, module=None, is_kernel=False, function_types=dict(), debug=False):
def __init__(self, context, prototype, gscope, attributes, constants, function_name,
module=None, is_kernel=False, function_types: Optional[Dict] = None, debug=False):
self.builder = _triton.ir.builder(context)
self.module = self.builder.create_module() if module is None else module
self.function_ret_types = function_types
self.function_ret_types = {} if function_types is None else function_types
self.prototype = prototype
self.gscope = gscope
self.lscope = dict()
@@ -122,45 +123,45 @@ class CodeGenerator(ast.NodeVisitor):
self.is_kernel = is_kernel
self.last_node = None
self.debug = debug
self.builtins = {
'range': range,
'min': triton.language.minimum,
'float': float,
'int': int,
'print': triton.language.core.device_print,
'isinstance': isinstance,
'getattr': getattr,
}
self.static_functions = [
'static_print', 'static_assert'
]
self.scf_stack = []
# SSA-construction
# name => triton.language.tensor
self.local_defs: Dict[str, triton.language.tensor] = {}
self.global_uses: Dict[str, triton.language.tensor] = {}
self.dereference_name: Callable[[str], Any] = self._define_name_lookup()
def get_value(self, name):
    """Return the value bound to `name`.

    Scopes are searched in this order: `self.lscope`, `self.gscope`,
    `self.builtins`; the first scope containing `name` wins.

    A name found in `self.lscope` that has no entry in `self.local_defs`
    is also recorded in `self.global_uses`.

    Raises:
        ValueError: if `name` is present in none of the three scopes.
    """
    # local scope first
    if name in self.lscope:
        found = self.lscope[name]
        if name not in self.local_defs:
            self.global_uses[name] = found
        return found
    # then the global scope
    if name in self.gscope:
        return self.gscope[name]
    # finally the builtins
    if name in self.builtins:
        return self.builtins[name]
    raise ValueError(f'{name} is not defined')
builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (range, float, int, isinstance, getattr)}
def _define_name_lookup(self):
    """Build and return the callable used to resolve names in the visited source.

    Besides building the lookup closure, this method registers the
    `print`/`min` entries in `builtin_namespace` and installs the handlers
    for the statically implemented functions (`static_assert`,
    `static_print`).

    Returns:
        A function `name_lookup(name)` that searches, in order, the current
        `self.lscope`, `self.gscope` and `self.builtin_namespace`, returning
        the first value found. A name found in `self.lscope` that has no
        entry in `self.local_defs` is also recorded in `self.global_uses`.
        The function raises `NameError` when the name is found nowhere.
    """
    # TODO: this needs to be moved to class scope when cyclic imports untangled and `triton.language` can be imported at module level
    self.builtin_namespace.update((
        ('print', triton.language.core.device_print),
        ('min', triton.language.minimum),  # TODO: why `min`? if `min`, why not `max`? `sum`? `all`?
    ))

    # The handlers below are *bound* methods. Writing them into the shared
    # class-level dict would make every generator dispatch to whichever
    # instance was constructed last (and keep that instance alive), so the
    # registry is materialised per instance instead. Entries already present
    # on the class-level dict are preserved.
    # TODO: once cyclic imports are untangled, register unbound functions at class scope instead
    self.statically_implemented_functions = {
        **self.statically_implemented_functions,
        triton.language.core.static_assert: self.execute_static_assert,
        triton.language.core.static_print: self.execute_static_print,
    }

    def local_lookup(name: str, absent):
        value = self.lscope.get(name, absent)  # this needs to be re-fetched from `self` every time, because it gets switched occasionally
        if value is not absent and name not in self.local_defs:
            self.global_uses[name] = value
        return value

    lookup_order = local_lookup, self.gscope.get, self.builtin_namespace.get
    # unique sentinel: lets a stored `None` be told apart from "not found"
    absent_marker = object()

    def name_lookup(name: str) -> Any:
        absent = absent_marker
        for lookup_function in lookup_order:
            value = lookup_function(name, absent)
            if value is not absent:
                return value
        raise NameError(f'{name} is not defined')

    return name_lookup
def set_value(self, name: str,
value: Union[triton.language.tensor, triton.language.constexpr]) -> None:
@@ -328,7 +329,8 @@ class CodeGenerator(ast.NodeVisitor):
_names = []
for target in node.targets:
_names += [self.visit(target)]
assert len(_names) == 1
if len(_names) > 1:
raise NotImplementedError("Multiple assignment is not supported.")
names = _names[0]
values = self.visit(node.value)
if not isinstance(names, tuple):
@@ -349,12 +351,12 @@ class CodeGenerator(ast.NodeVisitor):
rhs = ast.BinOp(lhs, node.op, node.value)
assign = ast.Assign(targets=[node.target], value=rhs)
self.visit(assign)
return self.get_value(name)
return self.dereference_name(name)
def visit_Name(self, node):
if type(node.ctx) == ast.Store:
return node.id
return self.get_value(node.id)
return self.dereference_name(node.id)
def visit_Store(self, node):
ast.NodeVisitor.generic_visit(self, node)
@@ -671,7 +673,7 @@ class CodeGenerator(ast.NodeVisitor):
ast.NodeVisitor.generic_visit(self, stmt)
return
if IteratorClass != self.builtins['range']:
if IteratorClass is not range:
raise RuntimeError('Only `range` and `static_range` iterators are currently supported')
# visit iterator arguments
@@ -785,8 +787,8 @@ class CodeGenerator(ast.NodeVisitor):
def visit_Index(self, node):
return self.visit(node.value)
def visit_keyword(self, node):
return {node.arg: self.visit(node.value)}
def visit_keyword(self, node) -> Tuple[str, Any]:
return node.arg, self.visit(node.value)
def visit_Assert(self, node) -> Any:
if not self.debug:
@@ -800,22 +802,16 @@ class CodeGenerator(ast.NodeVisitor):
fn = self.visit(node.func)
if isinstance(fn, triton.language.constexpr):
fn = fn.value
kws = dict()
for keyword in node.keywords:
kws.update(self.visit(keyword))
static_implementation = self.statically_implemented_functions.get(fn)
if static_implementation is not None:
return static_implementation(node)
kws = dict(self.visit(keyword) for keyword in node.keywords)
args = [self.visit(arg) for arg in node.args]
if fn.__name__ == "print":
fn = self.builtins["print"]
elif fn.__name__ == "device_assert":
if fn is triton.language.core.device_assert: # TODO: this should not be so hardcoded
if not self.debug:
return
elif fn.__name__ in self.static_functions:
if fn.__name__ == "static_print":
print(*args, **kws)
return
elif fn.__name__ == "static_assert":
assert args[0], args[1]
return
if isinstance(fn, triton.runtime.JITFunction):
from inspect import getcallargs
args = getcallargs(fn.fn, *args, **kws)
@@ -853,12 +849,10 @@ class CodeGenerator(ast.NodeVisitor):
for i in range(call_op.get_num_results()):
results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i]))
return tuple(results)
if (hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__)) \
or impl.is_builtin(fn):
if (hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__)) or impl.is_builtin(fn):
return fn(*args, _builder=self.builder, **kws)
if fn in self.builtins.values():
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
for arg in args]
if fn in self.builtin_namespace.values():
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg for arg in args]
return fn(*args, **kws)
def visit_Constant(self, node):
@@ -935,19 +929,77 @@ class CodeGenerator(ast.NodeVisitor):
typename = type(node).__name__
raise NotImplementedError("Unsupported node: {}".format(typename))
# TODO: populate this here (rather than inside `_define_name_lookup`) once cyclic imports resolved
statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = {}
def execute_static_print(self, node: ast.Call) -> None:
    """Handle a `static_print(...)` call by printing at compile time.

    Every keyword and positional argument of `node` is visited (keywords
    first, then positionals); `constexpr` results are replaced by their
    `.value`, anything else is passed through unchanged. The collected
    values are forwarded to the builtin `print`.
    """
    # TODO: too simplistic? Perhaps do something else with non-constexpr
    def plain(item):
        if isinstance(item, triton.language.constexpr):
            return item.value
        return item

    keyword_values = {}
    for keyword in node.keywords:
        key, visited = self.visit(keyword)
        keyword_values[key] = plain(visited)

    positional = []
    for arg in node.args:
        positional.append(plain(self.visit(arg)))

    print(*positional, **keyword_values)
def execute_static_assert(self, node: ast.Call) -> None:
    """Evaluate a `static_assert(condition[, message])` call at compile time.

    The condition (and, on failure, the optional message) are visited and
    unwrapped from `constexpr`, matching what `execute_static_print` does
    with its arguments, so a bare `constexpr` flag is accepted as a
    condition and a `constexpr` message becomes plain text.

    Args:
        node: the `ast.Call` node of the `static_assert` invocation.

    Raises:
        TypeError: if the call does not have exactly one or two positional
            arguments, or has any keyword argument.
        NotImplementedError: if the condition does not evaluate to a `bool`
            at compile time.
        CompileTimeAssertionFailure: if the condition is `False`. It is
            raised with `src=None`.
    """
    def unwrap(value):
        while isinstance(value, triton.language.constexpr):
            value = value.value
        return value

    arg_count = len(node.args)
    if not (0 < arg_count <= 2) or len(node.keywords):
        raise TypeError("`static_assert` requires one or two positional arguments only")

    passed = unwrap(self.visit(node.args[0]))
    if not isinstance(passed, bool):
        raise NotImplementedError("Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values")
    if passed:
        return None

    if arg_count == 1:
        message = ""
    else:
        # Best effort: a broken message expression must not mask the
        # assertion failure itself, so any evaluation error is reported
        # inside the message instead of propagating.
        try:
            message = unwrap(self.visit(node.args[1]))
            if not isinstance(message, str):
                message = str(message)
        except Exception as e:
            message = "<failed to evaluate assertion message: " + repr(e) + ">"
    raise CompileTimeAssertionFailure(None, node, message)
class CompilationError(Exception):
def __init__(self, src, node):
self.message = f'at {node.lineno}:{node.col_offset}:\n'
self.message += '\n'.join(src.split('\n')[:node.lineno])
self.message += '\n' + ' ' * node.col_offset + '^'
source_line_count_max_in_message = 12
def _format_message(self) -> str:
    """Render the text shown for this error.

    Layout, one item per line:
      1. `at <lineno>:<col_offset>:` (followed on the same line by
         ` <source unavailable>` when `self.src` is None);
      2. when source is available, the source lines up to and including
         `node.lineno`, limited to the last
         `source_line_count_max_in_message` of them, then a `^` marker
         indented by `node.col_offset`;
      3. `self.error_message`, when it is non-empty.
    """
    node = self.node
    header = f'at {node.lineno}:{node.col_offset}:'
    if self.src is None:
        parts = [header + " <source unavailable>"]
    else:
        excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:]
        # The header gets a line of its own; otherwise the first source line
        # is glued to it and its indentation no longer matches the `^` marker.
        parts = [header, *excerpt, ' ' * node.col_offset + '^']
    if self.error_message:
        parts.append(self.error_message)
    return '\n'.join(parts)
def __init__(self, src: Optional[str], node: ast.AST, error_message: Optional[str]):
self.src = src
self.node = node
super().__init__(self.message)
self.error_message = error_message
self.message = self._format_message()
def __str__(self):
return self.message
def __repr__(self):
return "{}({!r})".format(type(self).__name__, self.message)
def __reduce__(self):
# this is necessary to make CompilationError picklable
return (type(self), (self.src, self.node))
return type(self), (self.src, self.node, self.error_message)
class CompileTimeAssertionFailure(CompilationError):
    """Raised when the condition of a `static_assert` call evaluates to false."""

    def set_source_code(self, src: Optional[str]):
        """Attach (or replace) the source text and re-render `self.message` from it."""
        self.src = src
        # `src` must be stored before re-rendering: the message is built from `self.src`
        refreshed = self._format_message()
        self.message = refreshed
class OutOfResources(Exception):
@@ -1012,11 +1064,16 @@ def build_triton_ir(fn, signature, specialization, constants, debug=False):
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, attributes=new_attrs, is_kernel=True, debug=debug)
try:
generator.visit(fn.parse())
except CompileTimeAssertionFailure as e:
e.set_source_code(fn.src)
raise
except CompilationError: # (can this ever happen? nobody has access to fn.src except here)
raise # unchanged
except Exception as e:
node = generator.last_node
if node is None or isinstance(e, CompilationError):
raise e
raise CompilationError(fn.src, node) from e
if node is None:
raise
raise CompilationError(fn.src, node, repr(e)) from e
ret = generator.module
# module takes ownership of the context
ret.context = context