mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] code_generator.py TODOs fixed & removed (#1484)
Handled TODOs that were waiting for the circular import issue to be resolved
This commit is contained in:
@@ -5,6 +5,7 @@ import warnings
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
from .. import language
|
||||
from ..language import constexpr, tensor
|
||||
# ideally we wouldn't need any runtime component
|
||||
from ..runtime import JITFunction
|
||||
from .errors import (CompilationError, CompileTimeAssertionFailure,
|
||||
@@ -49,15 +50,15 @@ def mangle_fn(name, arg_tys, constants):
|
||||
|
||||
|
||||
def _is_triton_tensor(o: Any) -> bool:
|
||||
return isinstance(o, language.tensor)
|
||||
return isinstance(o, tensor)
|
||||
|
||||
|
||||
def _is_constexpr(o: Any) -> bool:
|
||||
return isinstance(o, language.constexpr) # TODO: fetch language.constexpr to a global after circular imports untangled, saving getattr
|
||||
return isinstance(o, constexpr)
|
||||
|
||||
|
||||
def _unwrap_if_constexpr(o: Any):
|
||||
return o.value if isinstance(o, language.constexpr) else o
|
||||
return o.value if isinstance(o, constexpr) else o
|
||||
|
||||
|
||||
_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels
|
||||
@@ -100,24 +101,17 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.scf_stack = []
|
||||
# SSA-construction
|
||||
# name => language.tensor
|
||||
self.local_defs: Dict[str, language.tensor] = {}
|
||||
self.global_uses: Dict[str, language.tensor] = {}
|
||||
self.local_defs: Dict[str, tensor] = {}
|
||||
self.global_uses: Dict[str, tensor] = {}
|
||||
self.dereference_name: Callable[[str], Any] = self._define_name_lookup()
|
||||
|
||||
builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (range, float, int, isinstance, getattr)}
|
||||
builtin_namespace.update((
|
||||
('print', language.core.device_print),
|
||||
('min', language.minimum),
|
||||
))
|
||||
|
||||
def _define_name_lookup(self):
|
||||
# TODO: this needs to be moved to class scope when cyclic imports untangled and `language` can be imported at module level
|
||||
self.builtin_namespace.update((
|
||||
('print', language.core.device_print),
|
||||
('min', language.minimum), # TODO: why `min`? if `min`, why not `max`? `sum`? `all`?
|
||||
))
|
||||
# TODO: this needs to be moved to class scope when cyclic imports untangled and `language` can be imported at module level
|
||||
self.statically_implemented_functions.update((
|
||||
(language.core.static_assert, CodeGenerator.execute_static_assert),
|
||||
(language.core.static_print, CodeGenerator.execute_static_print),
|
||||
))
|
||||
|
||||
def local_lookup(name: str, absent):
|
||||
value = self.lscope.get(name, absent) # this needs to be re-fetched from `self` every time, because it gets switched occasionally
|
||||
if value is not absent and name not in self.local_defs:
|
||||
@@ -127,9 +121,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
absent_marker = object()
|
||||
|
||||
def name_lookup(name: str) -> Any:
|
||||
lookup_order = local_lookup, self.gscope.get, self.builtin_namespace.get
|
||||
absent = absent_marker
|
||||
for lookup_function in lookup_order:
|
||||
for lookup_function in local_lookup, self.gscope.get, self.builtin_namespace.get:
|
||||
value = lookup_function(name, absent)
|
||||
if value is not absent:
|
||||
return value
|
||||
@@ -138,7 +131,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
return name_lookup
|
||||
|
||||
def set_value(self, name: str,
|
||||
value: Union[language.tensor, language.constexpr]) -> None:
|
||||
value: Union[tensor, constexpr]) -> None:
|
||||
''' This function:
|
||||
called by visit_Assign() & visit_FuncDef() to store left value (lvalue)
|
||||
1. record local defined name (FIXME: should consider control flow)
|
||||
@@ -243,13 +236,13 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if i in self.constants:
|
||||
cst = self.constants[i]
|
||||
if not _is_constexpr(cst):
|
||||
cst = language.constexpr(self.constants[i])
|
||||
cst = constexpr(self.constants[i])
|
||||
arg_values.append(cst)
|
||||
continue
|
||||
else:
|
||||
if i in self.attributes:
|
||||
fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i][1])
|
||||
arg_values.append(language.tensor(fn.args(idx), self.prototype.param_types[idx]))
|
||||
arg_values.append(tensor(fn.args(idx), self.prototype.param_types[idx]))
|
||||
idx += 1
|
||||
|
||||
insert_pt = self.builder.get_insertion_block()
|
||||
@@ -289,12 +282,12 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
target = self.visit(node.target)
|
||||
value = self.visit(node.value)
|
||||
# constexpr
|
||||
if annotation == language.constexpr:
|
||||
if annotation == constexpr:
|
||||
if target in self.lscope:
|
||||
raise ValueError(f'{target} is already defined.'
|
||||
f' constexpr cannot be reassigned.')
|
||||
if not _is_constexpr(value):
|
||||
value = language.constexpr(value)
|
||||
value = constexpr(value)
|
||||
self.lscope[target] = value
|
||||
return self.lscope[target]
|
||||
# default: call visit_Assign
|
||||
@@ -515,9 +508,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
lhs = _unwrap_if_constexpr(self.visit(node.left))
|
||||
rhs = _unwrap_if_constexpr(self.visit(node.comparators[0]))
|
||||
if type(node.ops[0]) == ast.Is:
|
||||
return language.constexpr(lhs is rhs)
|
||||
return constexpr(lhs is rhs)
|
||||
if type(node.ops[0]) == ast.IsNot:
|
||||
return language.constexpr(lhs is not rhs)
|
||||
return constexpr(lhs is not rhs)
|
||||
method_name = self._method_name_for_comp_op.get(type(node.ops[0]))
|
||||
if method_name is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
|
||||
@@ -631,7 +624,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
iterator.end.value,
|
||||
iterator.step.value)
|
||||
for i in static_range:
|
||||
self.lscope[node.target.id] = language.constexpr(i)
|
||||
self.lscope[node.target.id] = constexpr(i)
|
||||
self.visit_compound_statement(node.body)
|
||||
for stmt in node.orelse:
|
||||
ast.NodeVisitor.generic_visit(self, stmt)
|
||||
@@ -649,7 +642,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# handle negative constant step (not supported by scf.for in MLIR)
|
||||
negative_step = False
|
||||
if _is_constexpr(step) and step.value < 0:
|
||||
step = language.constexpr(-step.value)
|
||||
step = constexpr(-step.value)
|
||||
negative_step = True
|
||||
lb, ub = ub, lb
|
||||
lb = language.core._to_tensor(lb, self.builder)
|
||||
@@ -779,7 +772,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
args = getcallargs(fn.fn, *args, **kws)
|
||||
args = [args[name] for name in fn.arg_names]
|
||||
args = [arg if _is_triton_tensor(arg)
|
||||
else language.constexpr(arg) for arg in args]
|
||||
else constexpr(arg) for arg in args]
|
||||
# generate function def
|
||||
attributes = dict()
|
||||
constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
|
||||
@@ -804,12 +797,12 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if call_op.get_num_results() == 0 or callee_ret_type is None:
|
||||
return None
|
||||
elif call_op.get_num_results() == 1:
|
||||
return language.tensor(call_op.get_result(0), callee_ret_type)
|
||||
return tensor(call_op.get_result(0), callee_ret_type)
|
||||
else:
|
||||
# should return a tuple of tl.tensor
|
||||
results = []
|
||||
for i in range(call_op.get_num_results()):
|
||||
results.append(language.tensor(call_op.get_result(i), callee_ret_type[i]))
|
||||
results.append(tensor(call_op.get_result(i), callee_ret_type[i]))
|
||||
return tuple(results)
|
||||
if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or language.core.is_builtin(fn):
|
||||
return fn(*args, _builder=self.builder, **kws)
|
||||
@@ -818,7 +811,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
return fn(*args, **kws)
|
||||
|
||||
def visit_Constant(self, node):
|
||||
return language.constexpr(node.value)
|
||||
return constexpr(node.value)
|
||||
|
||||
def visit_BoolOp(self, node: ast.BoolOp):
|
||||
if len(node.values) != 2:
|
||||
@@ -833,13 +826,13 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
def visit_NameConstant(self, node):
|
||||
return language.constexpr(node.value)
|
||||
return constexpr(node.value)
|
||||
|
||||
def visit_Num(self, node):
|
||||
return language.constexpr(node.n)
|
||||
return constexpr(node.n)
|
||||
|
||||
def visit_Str(self, node):
|
||||
return language.constexpr(ast.literal_eval(node))
|
||||
return constexpr(ast.literal_eval(node))
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
lhs = self.visit(node.value)
|
||||
@@ -883,9 +876,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def generic_visit(self, node):
|
||||
raise UnsupportedLanguageConstruct(None, node, "unsupported AST node type: {}".format(type(node).__name__))
|
||||
|
||||
# TODO: populate this here (rather than inside `_define_name_lookup`) once cyclic imports resolved
|
||||
statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = {}
|
||||
|
||||
def execute_static_print(self, node: ast.Call) -> None:
|
||||
# TODO: too simplistic? Perhaps do something else with non-constexpr
|
||||
|
||||
@@ -913,6 +903,11 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
raise CompileTimeAssertionFailure(None, node, _unwrap_if_constexpr(message))
|
||||
return None
|
||||
|
||||
statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = {
|
||||
language.core.static_assert: execute_static_assert,
|
||||
language.core.static_print: execute_static_print,
|
||||
}
|
||||
|
||||
|
||||
def str_to_ty(name):
|
||||
if name[0] == "*":
|
||||
|
||||
Reference in New Issue
Block a user