[FRONTEND] Fix specialization on triton integer types (#2236)

https://github.com/openai/triton/issues/2231
This commit is contained in:
Keren Zhou
2023-09-04 02:57:08 -04:00
committed by GitHub
parent a539836876
commit 9e9fbe01f0
2 changed files with 21 additions and 5 deletions

View File

@@ -306,7 +306,7 @@ class JITFunction(KernelInterface[T]):
else (False,)'
elif 'Tensor' in arg_annotation:
return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0)'
elif arg_annotation == 'int':
elif 'int' in arg_annotation or 'bool' in arg_annotation:
return f'({arg} % {JITFunction.divisibility} == 0, {arg} % {JITFunction.divisibility_8} == 0, {arg} == 1)'
else:
return '(False,)'