mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] CodeGenerator: enhanced (#1355)
Contents of this change to `CodeGenerator`:

- addressed mutable default value in constructor (GitHub #1353)
- structured and faster name lookup (replaces `.get_value`)
- added informative error messages in some places
- tidy mechanism for "static" (compile-time) functions replaces the inline `if ... elif ...` chain in `.visit_Call`
- more robust `static_assert` and `static_print`
- more informative `CompilationError` display (saves scrolling up through long tracebacks)
- dedicated `CompileTimeAssertionFailure` exception for `static_assert`, which can be specially treated upstream by `Autotuner` to skip configurations that violate constraints (as for `OutOfResources`)

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -16,7 +16,7 @@ import tempfile
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
import setuptools
|
||||
import torch
|
||||
@@ -109,10 +109,11 @@ class enter_sub_region:
|
||||
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, function_name, module=None, is_kernel=False, function_types=dict(), debug=False):
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, function_name,
|
||||
module=None, is_kernel=False, function_types: Optional[Dict] = None, debug=False):
|
||||
self.builder = _triton.ir.builder(context)
|
||||
self.module = self.builder.create_module() if module is None else module
|
||||
self.function_ret_types = function_types
|
||||
self.function_ret_types = {} if function_types is None else function_types
|
||||
self.prototype = prototype
|
||||
self.gscope = gscope
|
||||
self.lscope = dict()
|
||||
@@ -122,45 +123,45 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.is_kernel = is_kernel
|
||||
self.last_node = None
|
||||
self.debug = debug
|
||||
self.builtins = {
|
||||
'range': range,
|
||||
'min': triton.language.minimum,
|
||||
'float': float,
|
||||
'int': int,
|
||||
'print': triton.language.core.device_print,
|
||||
'isinstance': isinstance,
|
||||
'getattr': getattr,
|
||||
}
|
||||
self.static_functions = [
|
||||
'static_print', 'static_assert'
|
||||
]
|
||||
self.scf_stack = []
|
||||
# SSA-construction
|
||||
# name => triton.language.tensor
|
||||
self.local_defs: Dict[str, triton.language.tensor] = {}
|
||||
self.global_uses: Dict[str, triton.language.tensor] = {}
|
||||
self.dereference_name: Callable[[str], Any] = self._define_name_lookup()
|
||||
|
||||
def get_value(self, name):
    """Resolve `name` to the object it currently refers to.

    Scopes are consulted in this order: local scope (`self.lscope`),
    global scope (`self.gscope`), then `self.builtins`.  The first scope
    that contains `name` wins.

    A name found in the local scope that has no entry in
    `self.local_defs` is additionally recorded in `self.global_uses`.

    Raises:
        ValueError: if `name` is not present in any of the three scopes.
    """
    # Local scope takes precedence over everything else.
    if name in self.lscope:
        found = self.lscope[name]
        if name not in self.local_defs:
            # Visible locally but not defined here: remember the use.
            self.global_uses[name] = found
        return found

    # Fall back to the global scope ...
    if name in self.gscope:
        return self.gscope[name]

    # ... and finally to the builtins table.
    if name in self.builtins:
        return self.builtins[name]

    raise ValueError(f'{name} is not defined')
|
||||
builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (range, float, int, isinstance, getattr)}
|
||||
|
||||
def _define_name_lookup(self):
    """Build and return the name-resolution function for this generator.

    Besides building the lookup closure, this method finishes populating
    two tables that cannot yet be filled at class-definition time
    (see the TODOs below):

    * `builtin_namespace` (class-level, shared) gets the `print` and
      `min` entries.  Sharing is harmless because the values do not
      depend on the instance.
    * `statically_implemented_functions` maps compile-time functions to
      the handlers that execute them.  The handlers are *bound methods of
      this instance*, so they are stored in a per-instance dict that
      shadows the class-level one.  Writing them into the shared class
      dict would make the most recently constructed generator handle
      `static_assert` / `static_print` for every other live generator.

    Returns:
        A callable `name_lookup(name) -> Any` that searches the local
        scope, then the global scope, then the builtin namespace, and
        raises `NameError` if the name is found in none of them.
    """
    # TODO: this needs to be moved to class scope when cyclic imports untangled and `triton.language` can be imported at module level
    self.builtin_namespace.update((
        ('print', triton.language.core.device_print),
        ('min', triton.language.minimum),  # TODO: why `min`? if `min`, why not `max`? `sum`? `all`?
    ))
    # TODO: this needs to be moved to class scope when cyclic imports untangled and `triton.language` can be imported at module level
    # Per-instance table: start from whatever the class declares, then add
    # handlers bound to *this* generator.
    self.statically_implemented_functions = {
        **type(self).statically_implemented_functions,
        triton.language.core.static_assert: self.execute_static_assert,
        triton.language.core.static_print: self.execute_static_print,
    }

    def local_lookup(name: str, absent):
        # `self.lscope` must be re-fetched from `self` on every call,
        # because it gets switched occasionally.
        value = self.lscope.get(name, absent)
        if value is not absent and name not in self.local_defs:
            # Visible locally but not defined here: remember the use.
            self.global_uses[name] = value
        return value

    lookup_order = local_lookup, self.gscope.get, self.builtin_namespace.get
    # Unique sentinel so that a stored `None` is still treated as "found".
    absent_marker = object()

    def name_lookup(name: str) -> Any:
        absent = absent_marker
        for lookup_function in lookup_order:
            value = lookup_function(name, absent)
            if value is not absent:
                return value
        raise NameError(f'{name} is not defined')

    return name_lookup
|
||||
|
||||
def set_value(self, name: str,
|
||||
value: Union[triton.language.tensor, triton.language.constexpr]) -> None:
|
||||
@@ -328,7 +329,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
_names = []
|
||||
for target in node.targets:
|
||||
_names += [self.visit(target)]
|
||||
assert len(_names) == 1
|
||||
if len(_names) > 1:
|
||||
raise NotImplementedError("Multiple assignment is not supported.")
|
||||
names = _names[0]
|
||||
values = self.visit(node.value)
|
||||
if not isinstance(names, tuple):
|
||||
@@ -349,12 +351,12 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
rhs = ast.BinOp(lhs, node.op, node.value)
|
||||
assign = ast.Assign(targets=[node.target], value=rhs)
|
||||
self.visit(assign)
|
||||
return self.get_value(name)
|
||||
return self.dereference_name(name)
|
||||
|
||||
def visit_Name(self, node):
|
||||
if type(node.ctx) == ast.Store:
|
||||
return node.id
|
||||
return self.get_value(node.id)
|
||||
return self.dereference_name(node.id)
|
||||
|
||||
def visit_Store(self, node):
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
@@ -671,7 +673,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ast.NodeVisitor.generic_visit(self, stmt)
|
||||
return
|
||||
|
||||
if IteratorClass != self.builtins['range']:
|
||||
if IteratorClass is not range:
|
||||
raise RuntimeError('Only `range` and `static_range` iterators are currently supported')
|
||||
|
||||
# visit iterator arguments
|
||||
@@ -785,8 +787,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def visit_Index(self, node):
|
||||
return self.visit(node.value)
|
||||
|
||||
def visit_keyword(self, node):
|
||||
return {node.arg: self.visit(node.value)}
|
||||
def visit_keyword(self, node) -> Tuple[str, Any]:
|
||||
return node.arg, self.visit(node.value)
|
||||
|
||||
def visit_Assert(self, node) -> Any:
|
||||
if not self.debug:
|
||||
@@ -800,22 +802,16 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
fn = self.visit(node.func)
|
||||
if isinstance(fn, triton.language.constexpr):
|
||||
fn = fn.value
|
||||
kws = dict()
|
||||
for keyword in node.keywords:
|
||||
kws.update(self.visit(keyword))
|
||||
|
||||
static_implementation = self.statically_implemented_functions.get(fn)
|
||||
if static_implementation is not None:
|
||||
return static_implementation(node)
|
||||
|
||||
kws = dict(self.visit(keyword) for keyword in node.keywords)
|
||||
args = [self.visit(arg) for arg in node.args]
|
||||
if fn.__name__ == "print":
|
||||
fn = self.builtins["print"]
|
||||
elif fn.__name__ == "device_assert":
|
||||
if fn is triton.language.core.device_assert: # TODO: this should not be so hardcoded
|
||||
if not self.debug:
|
||||
return
|
||||
elif fn.__name__ in self.static_functions:
|
||||
if fn.__name__ == "static_print":
|
||||
print(*args, **kws)
|
||||
return
|
||||
elif fn.__name__ == "static_assert":
|
||||
assert args[0], args[1]
|
||||
return
|
||||
if isinstance(fn, triton.runtime.JITFunction):
|
||||
from inspect import getcallargs
|
||||
args = getcallargs(fn.fn, *args, **kws)
|
||||
@@ -853,12 +849,10 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
for i in range(call_op.get_num_results()):
|
||||
results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i]))
|
||||
return tuple(results)
|
||||
if (hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__)) \
|
||||
or impl.is_builtin(fn):
|
||||
if (hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__)) or impl.is_builtin(fn):
|
||||
return fn(*args, _builder=self.builder, **kws)
|
||||
if fn in self.builtins.values():
|
||||
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
|
||||
for arg in args]
|
||||
if fn in self.builtin_namespace.values():
|
||||
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg for arg in args]
|
||||
return fn(*args, **kws)
|
||||
|
||||
def visit_Constant(self, node):
|
||||
@@ -935,19 +929,77 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
typename = type(node).__name__
|
||||
raise NotImplementedError("Unsupported node: {}".format(typename))
|
||||
|
||||
# TODO: populate this here (rather than inside `_define_name_lookup`) once cyclic imports resolved
|
||||
statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = {}
|
||||
|
||||
def execute_static_print(self, node: ast.Call) -> None:
    """Execute a `static_print(...)` call at compile time.

    Every keyword argument and then every positional argument of `node`
    is evaluated with `self.visit`; `constexpr` wrappers are replaced by
    their `.value`, and the results are forwarded to Python's `print`.
    """
    # TODO: too simplistic? Perhaps do something else with non-constexpr
    def strip_constexpr(obj):
        if isinstance(obj, triton.language.constexpr):
            return obj.value
        return obj

    # Keywords are evaluated before positional arguments.
    keyword_values = {}
    for keyword in node.keywords:
        key, raw = self.visit(keyword)
        keyword_values[key] = strip_constexpr(raw)

    positional_values = []
    for argument in node.args:
        positional_values.append(strip_constexpr(self.visit(argument)))

    print(*positional_values, **keyword_values)
|
||||
|
||||
def execute_static_assert(self, node: ast.Call) -> None:
    """Execute a `static_assert(cond[, msg])` call at compile time.

    The condition (and, on failure, the message) are evaluated with
    `self.visit`.  Because `self.visit` may hand back `constexpr`
    wrappers, both are unwrapped to their `.value` before being
    inspected, in the same way `execute_static_print` treats its
    arguments.

    Raises:
        TypeError: unless called with one or two positional arguments
            and no keyword arguments.
        NotImplementedError: if the condition does not evaluate to a
            plain `bool` at compile time.
        CompileTimeAssertionFailure: if the condition is `False`.  The
            source is passed as `None` here.
    """
    def unwrap(value):
        return value.value if isinstance(value, triton.language.constexpr) else value

    arg_count = len(node.args)
    if not (0 < arg_count <= 2) or len(node.keywords):
        raise TypeError("`static_assert` requires one or two positional arguments only")

    passed = unwrap(self.visit(node.args[0]))
    if not isinstance(passed, bool):
        raise NotImplementedError("Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values")
    if passed:
        return None

    if arg_count == 1:
        message = ""
    else:
        # Best effort: a broken message expression must not hide the
        # assertion failure itself.
        try:
            message = unwrap(self.visit(node.args[1]))
            if not isinstance(message, str):
                # The exception concatenates the message onto a string.
                message = str(message)
        except Exception as e:
            message = "<failed to evaluate assertion message: " + repr(e) + ">"

    raise CompileTimeAssertionFailure(None, node, message)
|
||||
|
||||
|
||||
class CompilationError(Exception):
|
||||
def __init__(self, src, node):
|
||||
self.message = f'at {node.lineno}:{node.col_offset}:\n'
|
||||
self.message += '\n'.join(src.split('\n')[:node.lineno])
|
||||
self.message += '\n' + ' ' * node.col_offset + '^'
|
||||
source_line_count_max_in_message = 12
|
||||
|
||||
def _format_message(self) -> str:
|
||||
node = self.node
|
||||
message = f'at {node.lineno}:{node.col_offset}:'
|
||||
if self.src is None:
|
||||
message += " <source unavailable>"
|
||||
else:
|
||||
message += '\n'.join(self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:])
|
||||
message += '\n' + ' ' * node.col_offset + '^'
|
||||
if self.error_message:
|
||||
message += '\n' + self.error_message
|
||||
return message
|
||||
|
||||
def __init__(self, src: Optional[str], node: ast.AST, error_message: Optional[str]):
|
||||
self.src = src
|
||||
self.node = node
|
||||
super().__init__(self.message)
|
||||
self.error_message = error_message
|
||||
self.message = self._format_message()
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
||||
def __repr__(self):
|
||||
return "{}({!r})".format(type(self).__name__, self.message)
|
||||
|
||||
def __reduce__(self):
|
||||
# this is necessary to make CompilationError picklable
|
||||
return (type(self), (self.src, self.node))
|
||||
return type(self), (self.src, self.node, self.error_message)
|
||||
|
||||
|
||||
class CompileTimeAssertionFailure(CompilationError):
    """Specific exception for failed tests in `static_assert` invocations"""

    def set_source_code(self, src: Optional[str]):
        """Attach (or replace) the source text and refresh the message.

        After storing `src`, the displayed message is regenerated with
        `_format_message()` so that it reflects the new source.
        """
        self.src = src
        # Rebuild the cached text from the new source.
        refreshed = self._format_message()
        self.message = refreshed
|
||||
|
||||
|
||||
class OutOfResources(Exception):
|
||||
@@ -1012,11 +1064,16 @@ def build_triton_ir(fn, signature, specialization, constants, debug=False):
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, attributes=new_attrs, is_kernel=True, debug=debug)
|
||||
try:
|
||||
generator.visit(fn.parse())
|
||||
except CompileTimeAssertionFailure as e:
|
||||
e.set_source_code(fn.src)
|
||||
raise
|
||||
except CompilationError: # (can this ever happen? nobody has access to fn.src except here)
|
||||
raise # unchanged
|
||||
except Exception as e:
|
||||
node = generator.last_node
|
||||
if node is None or isinstance(e, CompilationError):
|
||||
raise e
|
||||
raise CompilationError(fn.src, node) from e
|
||||
if node is None:
|
||||
raise
|
||||
raise CompilationError(fn.src, node, repr(e)) from e
|
||||
ret = generator.module
|
||||
# module takes ownership of the context
|
||||
ret.context = context
|
||||
|
||||
Reference in New Issue
Block a user