mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] fix handling of do_not_specialize with interior constantexprs (#2188)
This commit is contained in:
@@ -468,9 +468,6 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
self.arg_names = [v.name for v in signature.parameters.values()]
|
||||
self.arg_defaults = [v.default for v in signature.parameters.values()]
|
||||
self.has_defaults = any(v != inspect._empty for v in self.arg_defaults)
|
||||
# specialization hints
|
||||
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
|
||||
self.do_not_specialize = {self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
|
||||
# function source code (without decorators)
|
||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||
self.src = self.src[self.src.find("def"):]
|
||||
@@ -487,6 +484,10 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
self.__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
|
||||
# index of constexprs
|
||||
self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
|
||||
# specialization hints
|
||||
regular_args = [arg for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
|
||||
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
|
||||
self.do_not_specialize = {regular_args.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
|
||||
# tma info
|
||||
self.tensormaps_info = TMAInfos()
|
||||
# launcher
|
||||
|
||||
Reference in New Issue
Block a user