[FRONTEND] Typehint improvement (#1442)

Fixed bug with typehint checking. Refactored typehint code for
specializations. Added typehint checking for sig_keys.
This commit is contained in:
zahimoud
2023-03-29 18:12:40 -07:00
committed by GitHub
parent 43eed392df
commit 3fe2901bfc

View File

@@ -10,6 +10,8 @@ import textwrap
from collections import defaultdict, namedtuple
from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, cast, overload
import torch
import triton
from triton.utils import MockTensor
@@ -234,12 +236,36 @@ class JITFunction(KernelInterface[T]):
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
def _get_arg_specialization_key(self, arg) -> str:
arg_annotation = self.__annotations__.get(arg, None)
if not arg_annotation:
return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") \
else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) \
else (False,)'
elif arg_annotation is torch.Tensor:
return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0)'
elif arg_annotation is int:
return f'({arg} % {JITFunction.divisibility} == 0, {arg} == 1)'
else:
return '(False,)'
def _get_arg_sig_key(self, arg) -> str:
arg_annotation = self.__annotations__.get(arg, None)
if arg_annotation is torch.Tensor:
return f'{arg}.dtype'
elif arg_annotation is bool:
return "i1"
elif arg_annotation is float:
return 'fp32'
else:
return f'_key_of({arg})'
def _make_launcher(self):
regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs]
args = ', '.join(regular_args)
# cache key for regular argument type
sig_keys = ', '.join([f'_key_of({arg})' for arg in regular_args])
sig_keys = ', '.join([self._get_arg_sig_key(arg) for arg in regular_args])
# cache key for constexpr argument values
constexpr_keys = ', '.join(constexpr_args)
# cache key for argument specialization
@@ -247,17 +273,7 @@ class JITFunction(KernelInterface[T]):
for i, arg in enumerate(regular_args):
if i in self.do_not_specialize:
continue
arg_annotation = self.__annotations__.get(arg, None)
if not arg_annotation:
specializations += [f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") '
f'else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) '
f'else (False,)']
elif arg_annotation == 'torch.Tensor':
specializations += [f'({arg}.data_ptr() % {JITFunction.divisibility} == 0)']
elif arg_annotation == 'int':
specializations += [f'({arg} % {JITFunction.divisibility} == 0, {arg} == 1)']
else:
specializations += ['(False,)']
specializations += [self._get_arg_specialization_key(arg)]
spec_keys = ', '.join(specializations)
grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])