mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] In JITFunction: infer constexpr arg only if annotated as such (#1345)
Fixed `JITFunction.__init__` to mark args as constexpr only when the annotation is actually `tl.constexpr`, rather than treating any annotated arg as constexpr.
This commit is contained in:
@@ -316,7 +316,9 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
|
||||
self.__annotations__ = fn.__annotations__
|
||||
# index of constexprs
|
||||
self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()]
|
||||
from triton.language.core import \
|
||||
constexpr # import here rather than at module level due to circular import tangle
|
||||
self.constexprs = [index for index, ty in self.annotations.items() if issubclass(ty, constexpr)]
|
||||
# launcher
|
||||
self.run = self._make_launcher()
|
||||
# re-use docs of wrapped function
|
||||
|
||||
Reference in New Issue
Block a user