[FRONTEND] fix error with -> None return annotation (#1987)

None is not a type, so you get:
```
    self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
E   TypeError: argument of type 'NoneType' is not iterable
```

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
Shantanu
2023-07-25 18:49:45 -07:00
committed by GitHub
parent db695c093f
commit 4f1b2ea8d7

View File

@@ -122,6 +122,14 @@ def version_key():
return '-'.join(TRITON_VERSION) + '-' + ptxas_version + '-' + '-'.join(contents)
def _normalize_ty(ty) -> str:
if isinstance(ty, type):
return ty.__name__
elif isinstance(ty, str):
return ty
return repr(ty)
class KernelInterface(Generic[T]):
run: T
@@ -425,8 +433,7 @@ def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_stages=3, e
self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug
self.noinline = noinline
# annotations
normalize_ty = lambda ty: ty.__name__ if isinstance(ty, type) else ty
self.__annotations__ = {name: normalize_ty(ty) for name, ty in fn.__annotations__.items()}
self.__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
# index of constexprs
self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
# launcher