mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] fix error with -> None return annotation (#1987)
None is not a type, so you get:
```
self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
E TypeError: argument of type 'NoneType' is not iterable
```
Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -122,6 +122,14 @@ def version_key():
|
||||
return '-'.join(TRITON_VERSION) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||
|
||||
|
||||
def _normalize_ty(ty) -> str:
|
||||
if isinstance(ty, type):
|
||||
return ty.__name__
|
||||
elif isinstance(ty, str):
|
||||
return ty
|
||||
return repr(ty)
|
||||
|
||||
|
||||
class KernelInterface(Generic[T]):
|
||||
run: T
|
||||
|
||||
@@ -425,8 +433,7 @@ def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_stages=3, e
|
||||
self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug
|
||||
self.noinline = noinline
|
||||
# annotations
|
||||
normalize_ty = lambda ty: ty.__name__ if isinstance(ty, type) else ty
|
||||
self.__annotations__ = {name: normalize_ty(ty) for name, ty in fn.__annotations__.items()}
|
||||
self.__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
|
||||
# index of constexprs
|
||||
self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
|
||||
# launcher
|
||||
|
||||
Reference in New Issue
Block a user