mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Better warning on nested jit functions (#2453)
This commit is contained in:
@@ -228,6 +228,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.local_defs: Dict[str, tensor] = {}
|
||||
self.global_uses: Dict[str, tensor] = {}
|
||||
self.dereference_name: Callable[[str], Any] = self._define_name_lookup()
|
||||
self.fn = None
|
||||
|
||||
builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (range, float, int, isinstance, getattr)}
|
||||
builtin_namespace.update((
|
||||
@@ -322,6 +323,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
arg_names, kwarg_names = self.visit(node.args)
|
||||
if self.fn:
|
||||
raise UnsupportedLanguageConstruct(None, node, "nested function definition is not supported.")
|
||||
# initialize defaults
|
||||
for i, default_value in enumerate(node.args.defaults):
|
||||
arg_node = node.args.args[-i - 1]
|
||||
@@ -335,9 +338,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.visit(init_node)
|
||||
# initialize function
|
||||
visibility = "public" if self.is_kernel else "private"
|
||||
fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline)
|
||||
self.module.push_back(fn)
|
||||
entry = fn.add_entry_block()
|
||||
self.fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline)
|
||||
self.module.push_back(self.fn)
|
||||
entry = self.fn.add_entry_block()
|
||||
arg_values = []
|
||||
idx = 0
|
||||
for i, arg_name in enumerate(arg_names):
|
||||
@@ -350,8 +353,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
else:
|
||||
if i in self.attributes:
|
||||
for name, value in self.attributes[i]:
|
||||
fn.set_arg_attr(idx, name, value)
|
||||
arg_values.append(tensor(fn.args(idx), self.prototype.param_types[idx]))
|
||||
self.fn.set_arg_attr(idx, name, value)
|
||||
arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx]))
|
||||
idx += 1
|
||||
|
||||
insert_pt = self.builder.get_insertion_block()
|
||||
@@ -367,14 +370,14 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# update return type
|
||||
if isinstance(self.last_ret_type, tuple):
|
||||
self.prototype.ret_types = list(self.last_ret_type)
|
||||
fn.reset_type(self.prototype.to_ir(self.builder))
|
||||
self.fn.reset_type(self.prototype.to_ir(self.builder))
|
||||
else:
|
||||
self.prototype.ret_types = [self.last_ret_type]
|
||||
fn.reset_type(self.prototype.to_ir(self.builder))
|
||||
self.fn.reset_type(self.prototype.to_ir(self.builder))
|
||||
if insert_pt:
|
||||
self.builder.set_insertion_point_to_end(insert_pt)
|
||||
# Remove dead code
|
||||
fn.finalize()
|
||||
self.fn.finalize()
|
||||
|
||||
def visit_arguments(self, node):
|
||||
arg_names = []
|
||||
|
||||
Reference in New Issue
Block a user