[FRONTEND] code_generator.py TODOs fixed & removed (#1484)

Handled TODOs that were waiting for the circular import issue to be
resolved
This commit is contained in:
mcskatkat
2023-04-08 08:05:46 +03:00
committed by GitHub
parent bc0b007e4b
commit 82ec1a89ea

View File

@@ -5,6 +5,7 @@ import warnings
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from .. import language
from ..language import constexpr, tensor
# ideally we wouldn't need any runtime component
from ..runtime import JITFunction
from .errors import (CompilationError, CompileTimeAssertionFailure,
@@ -49,15 +50,15 @@ def mangle_fn(name, arg_tys, constants):
def _is_triton_tensor(o: Any) -> bool:
return isinstance(o, language.tensor)
return isinstance(o, tensor)
def _is_constexpr(o: Any) -> bool:
return isinstance(o, language.constexpr) # TODO: fetch language.constexpr to a global after circular imports untangled, saving getattr
return isinstance(o, constexpr)
def _unwrap_if_constexpr(o: Any):
return o.value if isinstance(o, language.constexpr) else o
return o.value if isinstance(o, constexpr) else o
_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels
@@ -100,24 +101,17 @@ class CodeGenerator(ast.NodeVisitor):
self.scf_stack = []
# SSA-construction
# name => language.tensor
self.local_defs: Dict[str, language.tensor] = {}
self.global_uses: Dict[str, language.tensor] = {}
self.local_defs: Dict[str, tensor] = {}
self.global_uses: Dict[str, tensor] = {}
self.dereference_name: Callable[[str], Any] = self._define_name_lookup()
builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (range, float, int, isinstance, getattr)}
builtin_namespace.update((
('print', language.core.device_print),
('min', language.minimum),
))
def _define_name_lookup(self):
# TODO: this needs to be moved to class scope when cyclic imports untangled and `language` can be imported at module level
self.builtin_namespace.update((
('print', language.core.device_print),
('min', language.minimum), # TODO: why `min`? if `min`, why not `max`? `sum`? `all`?
))
# TODO: this needs to be moved to class scope when cyclic imports untangled and `language` can be imported at module level
self.statically_implemented_functions.update((
(language.core.static_assert, CodeGenerator.execute_static_assert),
(language.core.static_print, CodeGenerator.execute_static_print),
))
def local_lookup(name: str, absent):
value = self.lscope.get(name, absent) # this needs to be re-fetched from `self` every time, because it gets switched occasionally
if value is not absent and name not in self.local_defs:
@@ -127,9 +121,8 @@ class CodeGenerator(ast.NodeVisitor):
absent_marker = object()
def name_lookup(name: str) -> Any:
lookup_order = local_lookup, self.gscope.get, self.builtin_namespace.get
absent = absent_marker
for lookup_function in lookup_order:
for lookup_function in local_lookup, self.gscope.get, self.builtin_namespace.get:
value = lookup_function(name, absent)
if value is not absent:
return value
@@ -138,7 +131,7 @@ class CodeGenerator(ast.NodeVisitor):
return name_lookup
def set_value(self, name: str,
value: Union[language.tensor, language.constexpr]) -> None:
value: Union[tensor, constexpr]) -> None:
''' This function:
called by visit_Assign() & visit_FuncDef() to store left value (lvalue)
1. record local defined name (FIXME: should consider control flow)
@@ -243,13 +236,13 @@ class CodeGenerator(ast.NodeVisitor):
if i in self.constants:
cst = self.constants[i]
if not _is_constexpr(cst):
cst = language.constexpr(self.constants[i])
cst = constexpr(self.constants[i])
arg_values.append(cst)
continue
else:
if i in self.attributes:
fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i][1])
arg_values.append(language.tensor(fn.args(idx), self.prototype.param_types[idx]))
arg_values.append(tensor(fn.args(idx), self.prototype.param_types[idx]))
idx += 1
insert_pt = self.builder.get_insertion_block()
@@ -289,12 +282,12 @@ class CodeGenerator(ast.NodeVisitor):
target = self.visit(node.target)
value = self.visit(node.value)
# constexpr
if annotation == language.constexpr:
if annotation == constexpr:
if target in self.lscope:
raise ValueError(f'{target} is already defined.'
f' constexpr cannot be reassigned.')
if not _is_constexpr(value):
value = language.constexpr(value)
value = constexpr(value)
self.lscope[target] = value
return self.lscope[target]
# default: call visit_Assign
@@ -515,9 +508,9 @@ class CodeGenerator(ast.NodeVisitor):
lhs = _unwrap_if_constexpr(self.visit(node.left))
rhs = _unwrap_if_constexpr(self.visit(node.comparators[0]))
if type(node.ops[0]) == ast.Is:
return language.constexpr(lhs is rhs)
return constexpr(lhs is rhs)
if type(node.ops[0]) == ast.IsNot:
return language.constexpr(lhs is not rhs)
return constexpr(lhs is not rhs)
method_name = self._method_name_for_comp_op.get(type(node.ops[0]))
if method_name is None:
raise UnsupportedLanguageConstruct(None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
@@ -631,7 +624,7 @@ class CodeGenerator(ast.NodeVisitor):
iterator.end.value,
iterator.step.value)
for i in static_range:
self.lscope[node.target.id] = language.constexpr(i)
self.lscope[node.target.id] = constexpr(i)
self.visit_compound_statement(node.body)
for stmt in node.orelse:
ast.NodeVisitor.generic_visit(self, stmt)
@@ -649,7 +642,7 @@ class CodeGenerator(ast.NodeVisitor):
# handle negative constant step (not supported by scf.for in MLIR)
negative_step = False
if _is_constexpr(step) and step.value < 0:
step = language.constexpr(-step.value)
step = constexpr(-step.value)
negative_step = True
lb, ub = ub, lb
lb = language.core._to_tensor(lb, self.builder)
@@ -779,7 +772,7 @@ class CodeGenerator(ast.NodeVisitor):
args = getcallargs(fn.fn, *args, **kws)
args = [args[name] for name in fn.arg_names]
args = [arg if _is_triton_tensor(arg)
else language.constexpr(arg) for arg in args]
else constexpr(arg) for arg in args]
# generate function def
attributes = dict()
constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
@@ -804,12 +797,12 @@ class CodeGenerator(ast.NodeVisitor):
if call_op.get_num_results() == 0 or callee_ret_type is None:
return None
elif call_op.get_num_results() == 1:
return language.tensor(call_op.get_result(0), callee_ret_type)
return tensor(call_op.get_result(0), callee_ret_type)
else:
# should return a tuple of tl.tensor
results = []
for i in range(call_op.get_num_results()):
results.append(language.tensor(call_op.get_result(i), callee_ret_type[i]))
results.append(tensor(call_op.get_result(i), callee_ret_type[i]))
return tuple(results)
if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or language.core.is_builtin(fn):
return fn(*args, _builder=self.builder, **kws)
@@ -818,7 +811,7 @@ class CodeGenerator(ast.NodeVisitor):
return fn(*args, **kws)
def visit_Constant(self, node):
return language.constexpr(node.value)
return constexpr(node.value)
def visit_BoolOp(self, node: ast.BoolOp):
if len(node.values) != 2:
@@ -833,13 +826,13 @@ class CodeGenerator(ast.NodeVisitor):
if sys.version_info < (3, 8):
def visit_NameConstant(self, node):
return language.constexpr(node.value)
return constexpr(node.value)
def visit_Num(self, node):
return language.constexpr(node.n)
return constexpr(node.n)
def visit_Str(self, node):
return language.constexpr(ast.literal_eval(node))
return constexpr(ast.literal_eval(node))
def visit_Attribute(self, node):
lhs = self.visit(node.value)
@@ -883,9 +876,6 @@ class CodeGenerator(ast.NodeVisitor):
def generic_visit(self, node):
raise UnsupportedLanguageConstruct(None, node, "unsupported AST node type: {}".format(type(node).__name__))
# TODO: populate this here (rather than inside `_define_name_lookup`) once cyclic imports resolved
statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = {}
def execute_static_print(self, node: ast.Call) -> None:
# TODO: too simplistic? Perhaps do something else with non-constexpr
@@ -913,6 +903,11 @@ class CodeGenerator(ast.NodeVisitor):
raise CompileTimeAssertionFailure(None, node, _unwrap_if_constexpr(message))
return None
statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = {
language.core.static_assert: execute_static_assert,
language.core.static_print: execute_static_print,
}
def str_to_ty(name):
if name[0] == "*":