[FRONTEND] fix for undefined dtypes in jit during loading defaults (#2114)

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
This commit is contained in:
Author: Mohammed Anany
Date: 2023-08-25 19:28:23 +02:00
Committed via GitHub
parent 56fee37a0d
commit ebfe0ffb29
7 changed files with 217 additions and 213 deletions

View File

@@ -2172,12 +2172,12 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
Y, stride_yk, stride_yn,
W, stride_wn, stride_wl,
Z, stride_zm, stride_zn,
out_dtype: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
ALLOW_TF32: tl.constexpr,
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
COL_A: tl.constexpr, COL_B: tl.constexpr):
COL_A: tl.constexpr, COL_B: tl.constexpr,
out_dtype: tl.constexpr = tl.float32):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
off_l = tl.arange(0, BLOCK_N)
@@ -2251,7 +2251,6 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
y_tri, y_tri.stride(0), y_tri.stride(1),
w_tri, w_tri.stride(0), w_tri.stride(1),
z_tri, z_tri.stride(0), z_tri.stride(1),
out_dtype,
COL_A=col_a, COL_B=col_b,
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
ADD_MATRIX=epilogue == 'add-matrix',
@@ -2260,7 +2259,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
DO_SOFTMAX=epilogue == 'softmax',
CHAIN_DOT=epilogue == 'chain-dot',
ALLOW_TF32=allow_tf32,
num_warps=num_warps, num_ctas=num_ctas)
num_warps=num_warps, num_ctas=num_ctas,
out_dtype=out_dtype)
if epilogue == 'softmax' and (in_dtype != 'float32' or allow_tf32):
ptx = pgm.asm["ptx"]
start = ptx.find("shfl.sync")

View File

@@ -116,5 +116,3 @@ def test_line_info(func: str):
assert (check_file_lines(file_lines, "standard.py", 33))
assert (check_file_lines(file_lines, "standard.py", 34))
assert (check_file_lines(file_lines, "standard.py", 36))
# core.py is changed frequently, so we only check if it exists
assert (check_file_lines(file_lines, "core.py", -1))

View File

@@ -4,11 +4,21 @@
from . import math
from . import extra
from .standard import (
argmax,
argmin,
cdiv,
cumprod,
cumsum,
max,
maximum,
min,
minimum,
sigmoid,
softmax,
sum,
ravel,
swizzle2d,
xor_sum,
zeros,
zeros_like,
)
@@ -17,8 +27,6 @@ from .core import (
abs,
advance,
arange,
argmin,
argmax,
associative_scan,
atomic_add,
atomic_and,
@@ -35,8 +43,6 @@ from .core import (
cat,
constexpr,
cos,
cumprod,
cumsum,
debug_barrier,
device_assert,
device_print,
@@ -63,12 +69,8 @@ from .core import (
load,
log,
make_block_ptr,
max,
max_constancy,
max_contiguous,
maximum,
min,
minimum,
multiple_of,
num_programs,
pi32_t,
@@ -81,7 +83,6 @@ from .core import (
static_assert,
static_print,
store,
sum,
static_range,
tensor,
trans,
@@ -94,7 +95,6 @@ from .core import (
view,
void,
where,
xor_sum,
)
from .random import (
pair_uniform_to_normal,

View File

@@ -6,8 +6,7 @@ from functools import wraps
from typing import Callable, List, Sequence, TypeVar
from .._C.libtriton.triton import ir
from ..runtime.jit import jit
from . import math, semantic
from . import semantic
T = TypeVar('T')
@@ -205,6 +204,10 @@ class dtype:
def is_bool(self):
return self.is_int1()
@staticmethod
def is_dtype(type_str):
return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES
@staticmethod
def is_void():
raise RuntimeError("Not implemented")
@@ -1380,170 +1383,6 @@ def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None
return rvalue, rindices
@jit
def minimum(x, y):
"""
Computes the element-wise minimum of :code:`x` and :code:`y`.
:param input: the first input tensor
:type input: Block
:param other: the second input tensor
:type other: Block
"""
return math.min(x, y)
@jit
def maximum(x, y):
"""
Computes the element-wise maximum of :code:`x` and :code:`y`.
:param input: the first input tensor
:type input: Block
:param other: the second input tensor
:type other: Block
"""
return math.max(x, y)
# max and argmax
@jit
def _argmax_combine(value1, index1, value2, index2, tie_break_left):
if tie_break_left:
tie = value1 == value2 and index1 < index2
else:
tie = False
gt = value1 > value2 or tie
v_ret = where(gt, value1, value2)
i_ret = where(gt, index1, index2)
return v_ret, i_ret
@jit
def _argmax_combine_tie_break_left(value1, index1, value2, index2):
return _argmax_combine(value1, index1, value2, index2, True)
@jit
def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
return _argmax_combine(value1, index1, value2, index2, False)
@jit
@_add_reduction_docstr("maximum",
return_indices_arg="return_indices",
tie_break_arg="return_indices_tie_break_left")
def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
input = _promote_reduction_input(input)
if return_indices:
if return_indices_tie_break_left:
return _reduce_with_indices(input, axis, _argmax_combine_tie_break_left)
else:
return _reduce_with_indices(input, axis, _argmax_combine_tie_break_fast)
else:
if constexpr(input.dtype.primitive_bitwidth) < 32:
if constexpr(input.dtype.is_floating()):
input = input.to(float32)
else:
assert input.dtype.is_integer_type()
input = input.to(int32)
return reduce(input, axis, maximum)
@jit
@_add_reduction_docstr("maximum index", tie_break_arg="tie_break_left")
def argmax(input, axis, tie_break_left=True):
(_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
return ret
# min and argmin
@jit
def _argmin_combine(value1, index1, value2, index2, tie_break_left):
if tie_break_left:
tie = value1 == value2 and index1 < index2
else:
tie = False
lt = value1 < value2 or tie
value_ret = where(lt, value1, value2)
index_ret = where(lt, index1, index2)
return value_ret, index_ret
@jit
def _argmin_combine_tie_break_left(value1, index1, value2, index2):
return _argmin_combine(value1, index1, value2, index2, True)
@jit
def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
return _argmin_combine(value1, index1, value2, index2, False)
@jit
@_add_reduction_docstr("minimum",
return_indices_arg="return_indices",
tie_break_arg="return_indices_tie_break_left")
def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
input = _promote_reduction_input(input)
if return_indices:
if return_indices_tie_break_left:
return _reduce_with_indices(input, axis, _argmin_combine_tie_break_left)
else:
return _reduce_with_indices(input, axis, _argmin_combine_tie_break_fast)
else:
if constexpr(input.dtype.primitive_bitwidth) < 32:
if constexpr(input.dtype.is_floating()):
input = input.to(float32)
else:
assert input.dtype.is_integer_type()
input = input.to(int32)
return reduce(input, axis, minimum)
@jit
@_add_reduction_docstr("minimum index",
tie_break_arg="tie_break_left")
def argmin(input, axis, tie_break_left=True):
_, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
return ret
@jit
def _sum_combine(a, b):
return a + b
# sum
@jit
@_add_reduction_docstr("sum")
def sum(input, axis=None):
input = _promote_reduction_input(input)
return reduce(input, axis, _sum_combine)
@jit
def _xor_combine(a, b):
return a ^ b
# xor sum
@builtin
@_add_reduction_docstr("xor sum")
def xor_sum(input, axis=None, _builder=None, _generator=None):
scalar_ty = input.type.scalar
if not scalar_ty.is_int():
raise ValueError("xor_sum only supported for integers")
input = _promote_reduction_input(input, _builder=_builder)
return reduce(input, axis, _xor_combine,
_builder=_builder, _generator=_generator)
# -----------------------
# Scans
# -----------------------
@@ -1594,31 +1433,6 @@ def associative_scan(input, axis, combine_fn, _builder=None, _generator=None):
axis = _constexpr_to_value(axis)
return semantic.associative_scan(input, axis, make_combine_region, _builder)
# cumsum
@jit
@_add_scan_docstr("cumsum")
def cumsum(input, axis=0):
# todo rename this to a generic function name
input = _promote_reduction_input(input)
return associative_scan(input, axis, _sum_combine)
# cumprod
@jit
def _prod_combine(a, b):
return a * b
@jit
@_add_scan_docstr("cumprod")
def cumprod(input, axis=0):
# todo rename this to a generic function name
input = _promote_reduction_input(input)
return associative_scan(input, axis, _prod_combine)
# -----------------------
# Compiler Hint Ops
# -----------------------

View File

@@ -1,5 +1,6 @@
from ..runtime.jit import jit
from . import core as tl
from . import standard
PHILOX_KEY_A: tl.constexpr = 0x9E3779B9
PHILOX_KEY_B: tl.constexpr = 0xBB67AE85
@@ -141,7 +142,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
@jit
def pair_uniform_to_normal(u1, u2):
"""Box-Muller transform"""
u1 = tl.maximum(1.0e-7, u1)
u1 = standard.maximum(1.0e-7, u1)
th = 6.283185307179586 * u2
r = tl.sqrt(-2.0 * tl.log(u1))
return r * tl.cos(th), r * tl.sin(th)

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from ..runtime.jit import jit
from . import core
from . import core, math
# -----------------------
# Standard library
@@ -30,9 +30,9 @@ def sigmoid(x):
@jit
@core._add_math_1arg_docstr("softmax")
def softmax(x, ieee_rounding=False):
z = x - core.max(x, 0)
z = x - max(x, 0)
num = core.exp(z)
den = core.sum(num, 0)
den = sum(num, 0)
return core.fdiv(num, den, ieee_rounding)
@@ -73,7 +73,7 @@ def swizzle2d(i, j, size_i, size_j, size_g):
# row-index of the first element of this group
off_i = group_id * size_g
# last group may have fewer rows
size_g = core.minimum(size_i - off_i, size_g)
size_g = minimum(size_i - off_i, size_g)
# new row and column indices
new_i = off_i + (ij % size_g)
new_j = (ij % size_gj) // size_g
@@ -96,3 +96,192 @@ def zeros(shape, dtype):
@jit
def zeros_like(input):
return zeros(input.shape, input.dtype)
@jit
def minimum(x, y):
"""
Computes the element-wise minimum of :code:`x` and :code:`y`.
:param input: the first input tensor
:type input: Block
:param other: the second input tensor
:type other: Block
"""
return math.min(x, y)
@jit
def maximum(x, y):
"""
Computes the element-wise maximum of :code:`x` and :code:`y`.
:param input: the first input tensor
:type input: Block
:param other: the second input tensor
:type other: Block
"""
return math.max(x, y)
# max and argmax
@jit
def _argmax_combine(value1, index1, value2, index2, tie_break_left):
if tie_break_left:
tie = value1 == value2 and index1 < index2
else:
tie = False
gt = value1 > value2 or tie
v_ret = core.where(gt, value1, value2)
i_ret = core.where(gt, index1, index2)
return v_ret, i_ret
@jit
def _argmax_combine_tie_break_left(value1, index1, value2, index2):
return _argmax_combine(value1, index1, value2, index2, True)
@jit
def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
return _argmax_combine(value1, index1, value2, index2, False)
@jit
@core._add_reduction_docstr("maximum",
return_indices_arg="return_indices",
tie_break_arg="return_indices_tie_break_left")
def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
input = core._promote_reduction_input(input)
if return_indices:
if return_indices_tie_break_left:
return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_left)
else:
return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast)
else:
if core.constexpr(input.dtype.primitive_bitwidth) < 32:
if core.constexpr(input.dtype.is_floating()):
input = input.to(core.float32)
else:
assert input.dtype.is_integer_type()
input = input.to(core.int32)
return core.reduce(input, axis, maximum)
@jit
@core._add_reduction_docstr("maximum index", tie_break_arg="tie_break_left")
def argmax(input, axis, tie_break_left=True):
(_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
return ret
# min and argmin
@jit
def _argmin_combine(value1, index1, value2, index2, tie_break_left):
if tie_break_left:
tie = value1 == value2 and index1 < index2
else:
tie = False
lt = value1 < value2 or tie
value_ret = core.where(lt, value1, value2)
index_ret = core.where(lt, index1, index2)
return value_ret, index_ret
@jit
def _argmin_combine_tie_break_left(value1, index1, value2, index2):
return _argmin_combine(value1, index1, value2, index2, True)
@jit
def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
return _argmin_combine(value1, index1, value2, index2, False)
@jit
@core._add_reduction_docstr("minimum",
return_indices_arg="return_indices",
tie_break_arg="return_indices_tie_break_left")
def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
input = core._promote_reduction_input(input)
if return_indices:
if return_indices_tie_break_left:
return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_left)
else:
return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_fast)
else:
if core.constexpr(input.dtype.primitive_bitwidth) < 32:
if core.constexpr(input.dtype.is_floating()):
input = input.to(core.float32)
else:
assert input.dtype.is_integer_type()
input = input.to(core.int32)
return core.reduce(input, axis, minimum)
@jit
@core._add_reduction_docstr("minimum index",
tie_break_arg="tie_break_left")
def argmin(input, axis, tie_break_left=True):
_, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
return ret
@jit
def _sum_combine(a, b):
return a + b
# sum
@jit
@core._add_reduction_docstr("sum")
def sum(input, axis=None):
input = core._promote_reduction_input(input)
return core.reduce(input, axis, _sum_combine)
@jit
def _xor_combine(a, b):
return a ^ b
# xor sum
@core.builtin
@core._add_reduction_docstr("xor sum")
def xor_sum(input, axis=None, _builder=None, _generator=None):
scalar_ty = input.type.scalar
if not scalar_ty.is_int():
raise ValueError("xor_sum only supported for integers")
input = core._promote_reduction_input(input, _builder=_builder)
return core.reduce(input, axis, _xor_combine,
_builder=_builder, _generator=_generator)
# cumsum
@jit
@core._add_scan_docstr("cumsum")
def cumsum(input, axis=0):
# todo rename this to a generic function name
input = core._promote_reduction_input(input)
return core.associative_scan(input, axis, _sum_combine)
# cumprod
@jit
def _prod_combine(a, b):
return a * b
@jit
@core._add_scan_docstr("cumprod")
def cumprod(input, axis=0):
# todo rename this to a generic function name
input = core._promote_reduction_input(input)
return core.associative_scan(input, axis, _prod_combine)

View File

@@ -13,6 +13,7 @@ from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union,
from .._C.libtriton.triton import TMAInfos
from ..common.backend import get_backend, path_to_ptxas
from ..language.core import dtype
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TRITON_VERSION = "2.1.0"
@@ -358,10 +359,11 @@ class JITFunction(KernelInterface[T]):
spec_keys = ', '.join(specializations)
grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = triton.language.dtype(\'{dflt}\')' if dtype.is_dtype(f'{dflt}') else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
src = f"""
import triton
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages
sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()}