[FRONTEND] Better warning on nested jit functions (#2453)

This commit is contained in:
Keren Zhou
2023-10-06 17:22:51 -04:00
committed by GitHub
parent eed4559df2
commit a42d517021

View File

@@ -228,6 +228,7 @@ class CodeGenerator(ast.NodeVisitor):
self.local_defs: Dict[str, tensor] = {}
self.global_uses: Dict[str, tensor] = {}
self.dereference_name: Callable[[str], Any] = self._define_name_lookup()
self.fn = None
builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (range, float, int, isinstance, getattr)}
builtin_namespace.update((
@@ -322,6 +323,8 @@ class CodeGenerator(ast.NodeVisitor):
def visit_FunctionDef(self, node):
arg_names, kwarg_names = self.visit(node.args)
if self.fn:
raise UnsupportedLanguageConstruct(None, node, "nested function definition is not supported.")
# initialize defaults
for i, default_value in enumerate(node.args.defaults):
arg_node = node.args.args[-i - 1]
@@ -335,9 +338,9 @@ class CodeGenerator(ast.NodeVisitor):
self.visit(init_node)
# initialize function
visibility = "public" if self.is_kernel else "private"
fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline)
self.module.push_back(fn)
entry = fn.add_entry_block()
self.fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline)
self.module.push_back(self.fn)
entry = self.fn.add_entry_block()
arg_values = []
idx = 0
for i, arg_name in enumerate(arg_names):
@@ -350,8 +353,8 @@ class CodeGenerator(ast.NodeVisitor):
else:
if i in self.attributes:
for name, value in self.attributes[i]:
fn.set_arg_attr(idx, name, value)
arg_values.append(tensor(fn.args(idx), self.prototype.param_types[idx]))
self.fn.set_arg_attr(idx, name, value)
arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx]))
idx += 1
insert_pt = self.builder.get_insertion_block()
@@ -367,14 +370,14 @@ class CodeGenerator(ast.NodeVisitor):
# update return type
if isinstance(self.last_ret_type, tuple):
self.prototype.ret_types = list(self.last_ret_type)
fn.reset_type(self.prototype.to_ir(self.builder))
self.fn.reset_type(self.prototype.to_ir(self.builder))
else:
self.prototype.ret_types = [self.last_ret_type]
fn.reset_type(self.prototype.to_ir(self.builder))
self.fn.reset_type(self.prototype.to_ir(self.builder))
if insert_pt:
self.builder.set_insertion_point_to_end(insert_pt)
# Remove dead code
fn.finalize()
self.fn.finalize()
def visit_arguments(self, node):
arg_names = []