[FRONTEND] In JITFunction: infer constexpr arg only if annotated as such (#1345)

Fixed `JITFunction.__init__` to mark args as constexpr only when the
annotation is actually `tl.constexpr`, rather than treating any
annotated arg as constexpr.
This commit is contained in:
mcskatkat
2023-03-16 01:39:45 +02:00
committed by GitHub
parent 109b5e2729
commit c175473bbf

View File

@@ -316,7 +316,9 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
self.__annotations__ = fn.__annotations__
# index of constexprs
self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()]
from triton.language.core import \
constexpr # import here rather than at module level due to circular import tangle
self.constexprs = [index for index, ty in self.annotations.items() if issubclass(ty, constexpr)]
# launcher
self.run = self._make_launcher()
# re-use docs of wrapped function