[FRONTEND] fix handling of do_not_specialize with interior constantexprs (#2188)

This commit is contained in:
Greg Brockman
2023-08-26 09:19:34 -07:00
committed by GitHub
parent 4ea94a8bcf
commit ab3e8b0dad

View File

@@ -468,9 +468,6 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
self.arg_names = [v.name for v in signature.parameters.values()]
self.arg_defaults = [v.default for v in signature.parameters.values()]
self.has_defaults = any(v != inspect._empty for v in self.arg_defaults)
# specialization hints
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
self.do_not_specialize = {self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
# function source code (without decorators)
self.src = textwrap.dedent(inspect.getsource(fn))
self.src = self.src[self.src.find("def"):]
@@ -487,6 +484,10 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
self.__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
# index of constexprs
self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
# specialization hints
regular_args = [arg for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
self.do_not_specialize = {regular_args.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
# tma info
self.tensormaps_info = TMAInfos()
# launcher