[FRONTEND] Typehint improvement (#1442)
Fixed a bug with typehint checking. Refactored the typehint code for specializations. Added typehint checking for sig_keys.
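The bug being fixed: __annotations__ stores the annotation objects themselves, not strings, so the old string comparisons (arg_annotation == 'torch.Tensor') could never match. A minimal sketch of the mismatch, with a made-up function f standing in for a kernel (not from the commit):

    import torch

    def f(x: torch.Tensor, n: int):
        pass

    print(f.__annotations__['x'] == 'torch.Tensor')  # False: the old string check never fired
    print(f.__annotations__['x'] is torch.Tensor)    # True: the fixed identity check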
@@ -10,6 +10,8 @@ import textwrap
 from collections import defaultdict, namedtuple
 from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, cast, overload
 
+import torch
+
 import triton
 from triton.utils import MockTensor
 
@@ -234,12 +236,36 @@ class JITFunction(KernelInterface[T]):
         return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
 
+    def _get_arg_specialization_key(self, arg) -> str:
+        arg_annotation = self.__annotations__.get(arg, None)
+        if not arg_annotation:
+            return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") \
+                    else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) \
+                    else (False,)'
+        elif arg_annotation is torch.Tensor:
+            return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0)'
+        elif arg_annotation is int:
+            return f'({arg} % {JITFunction.divisibility} == 0, {arg} == 1)'
+        else:
+            return '(False,)'
+
+    def _get_arg_sig_key(self, arg) -> str:
+        arg_annotation = self.__annotations__.get(arg, None)
+        if arg_annotation is torch.Tensor:
+            return f'{arg}.dtype'
+        elif arg_annotation is bool:
+            return "'i1'"
+        elif arg_annotation is float:
+            return "'fp32'"
+        else:
+            return f'_key_of({arg})'
+
     def _make_launcher(self):
         regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
         constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs]
         args = ', '.join(regular_args)
         # cache key for regular argument type
-        sig_keys = ', '.join([f'_key_of({arg})' for arg in regular_args])
+        sig_keys = ', '.join([self._get_arg_sig_key(arg) for arg in regular_args])
         # cache key for constexpr argument values
         constexpr_keys = ', '.join(constexpr_args)
         # cache key for argument specialization
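For illustration, a hedged sketch (argument names invented, and assuming JITFunction.divisibility == 16) of the source fragments the two new helpers emit into the generated launcher for a kernel declared as def kernel(x: torch.Tensor, n: int, flag: bool, eps: float, y):

    # specialization fragments, evaluated at each launch:
    #   x    -> (x.data_ptr() % 16 == 0)
    #   n    -> (n % 16 == 0, n == 1)
    #   y    -> the runtime hasattr/isinstance chain, since y is unannotated
    # signature-key fragments:
    #   x    -> x.dtype
    #   n    -> _key_of(n)
    #   flag -> 'i1'     # constant: annotated bools always key as i1
    #   eps  -> 'fp32'   # constant: annotated floats always key as fp32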
@@ -247,17 +273,7 @@ class JITFunction(KernelInterface[T]):
         for i, arg in enumerate(regular_args):
             if i in self.do_not_specialize:
                 continue
-            arg_annotation = self.__annotations__.get(arg, None)
-            if not arg_annotation:
-                specializations += [f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") '
-                                    f'else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) '
-                                    f'else (False,)']
-            elif arg_annotation == 'torch.Tensor':
-                specializations += [f'({arg}.data_ptr() % {JITFunction.divisibility} == 0)']
-            elif arg_annotation == 'int':
-                specializations += [f'({arg} % {JITFunction.divisibility} == 0, {arg} == 1)']
-            else:
-                specializations += ['(False,)']
+            specializations += [self._get_arg_specialization_key(arg)]
 
         spec_keys = ', '.join(specializations)
         grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
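Putting it together, a hedged usage sketch (kernel and names are illustrative, not from the commit; requires a CUDA device): annotating arguments lets the generated launcher bake these checks in at source-generation time instead of re-deriving them from runtime type inspection.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def add_one(x_ptr: torch.Tensor,   # sig key: x_ptr.dtype; specialized on 16-byte alignment
                n_elements: int,       # specialized on n_elements % 16 == 0 and n_elements == 1
                BLOCK: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        tl.store(x_ptr + offsets, x + 1, mask=mask)

    x = torch.zeros(1024, device='cuda')
    add_one[(4,)](x, x.numel(), BLOCK=256)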