mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
switch const to const_like [run_process_replay] (#6356)
* const like * no more _const * missed one * mypy ops.py * file missing * const_like * fix image and test uop graph [run_process_replay] * fix ptx
This commit is contained in:
@@ -13,7 +13,7 @@ simple_pm = PatternMatcher([
|
||||
(NOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
|
||||
(NOp.cvar('x') + NOp.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)),
|
||||
(NOp.cvar('x') * NOp.cvar('y') * NOp.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)),
|
||||
((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x + x.const(c1.arg+c2.arg)),
|
||||
((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)),
|
||||
])
|
||||
|
||||
def to_uops_list(u:List[UOp]) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u)))
|
||||
@@ -134,7 +134,7 @@ class TestGraphRewrite(unittest.TestCase):
|
||||
b = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('b', 0, 1))
|
||||
c = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('c', 0, 1))
|
||||
d = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('d', 0, 1))
|
||||
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
|
||||
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
|
||||
for out in outs:
|
||||
sink = graph_rewrite(out, constant_folder)
|
||||
print(sink)
|
||||
|
||||
@@ -93,16 +93,16 @@ class IndependentLowerer:
|
||||
idx, valid = x.st_arg.to_indexed_uops(self.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else self.idxs)
|
||||
# TODO: check has_valid in UPat, not here
|
||||
has_valid = valid.op is not UOps.CONST or valid.arg is not True
|
||||
if x.op is UOps.CONST: return valid.where(x.const(x.arg), x.const(0))
|
||||
if x.op is UOps.CONST: return valid.where(x.const_like(x.arg), x.const_like(0))
|
||||
buf = x.src[0]
|
||||
if x.op is UOps.LOAD:
|
||||
barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[2]),)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
|
||||
return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((x.const(0), valid) if has_valid else ()) + barrier)
|
||||
return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((x.const_like(0), valid) if has_valid else ()) + barrier)
|
||||
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
|
||||
store_back = x.src[0].op is UOps.DEFINE_LOCAL and x.src[2].op is UOps.REDUCE_AXIS and \
|
||||
x.src[2].src[0].op is UOps.LOAD and x.src[2].src[0].src[0].op is UOps.DEFINE_LOCAL
|
||||
# NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
|
||||
if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const(0) if i in x.src[2].arg[1] else u for i,u in enumerate(self.idxs)])
|
||||
if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if i in x.src[2].arg[1] else u for i,u in enumerate(self.idxs)])
|
||||
if x.src[0].op is UOps.DEFINE_GLOBAL or store_back:
|
||||
for oidx, ridx in zip(self.idxs, self.ridxs):
|
||||
if oidx != ridx: valid = valid * oidx.eq(0)
|
||||
|
||||
@@ -38,7 +38,7 @@ def rintk(d:UOp) -> UOp:
|
||||
"""ceiling(d:float) -> int"""
|
||||
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
return_t = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype]
|
||||
return (d + d.lt(0.0).where(d.const(-0.5), d.const(0.5))).cast(return_t)
|
||||
return (d + d.lt(0.0).where(d.const_like(-0.5), d.const_like(0.5))).cast(return_t)
|
||||
|
||||
def pow2if(q:UOp, float_dtype:DType):
|
||||
"""cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]"""
|
||||
@@ -82,7 +82,7 @@ def frexp(v:UOp) -> Tuple[UOp, UOp]:
|
||||
result_f = bits_to_float((bits & m1) | m2, v.dtype)
|
||||
value = exponent_zero.where(result_f, v)
|
||||
exp = exponent + (-bias)
|
||||
exp = exponent_zero.where(exp, exp.const(0))
|
||||
exp = exponent_zero.where(exp, exp.const_like(0))
|
||||
if v.dtype == dtypes.float16: exp = exp.bitcast(dtypes.int16)
|
||||
return value, exp
|
||||
|
||||
@@ -114,7 +114,7 @@ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
|
||||
def _take(an:UOp, offset:int, count:int=0) -> UOp:
|
||||
"""an = two_over_pi_f[i+offset]"""
|
||||
if count+offset <= len(two_over_pi_f[0:-2]):
|
||||
an = _eq(i, count).where(_take(an, offset, count=count+1), an.const(two_over_pi_f[count+offset]))
|
||||
an = _eq(i, count).where(_take(an, offset, count=count+1), an.const_like(two_over_pi_f[count+offset]))
|
||||
return an
|
||||
def _exact_pow2if(x): return pow2if(x, input_dtype).cast(acc_dtype)
|
||||
def _shl_lazy(x, y): return (x.cast(acc_dtype) * _exact_pow2if(y)).cast(dtypes.uint32)
|
||||
@@ -183,7 +183,7 @@ def trig_poly(d:UOp, coeff32, coeff64):
|
||||
u = __poly8(s, s2, s4, *coeff64[:-1])
|
||||
u = u * s + coeff64[-1]
|
||||
else:
|
||||
u = polyN(s.const(coeff32[0]), s, coeff32[1:])
|
||||
u = polyN(s.const_like(coeff32[0]), s, coeff32[1:])
|
||||
return s * (u * d) + d
|
||||
# approximate sine on [-pi/2, pi/2]
|
||||
def sin_poly(d:UOp) -> UOp:
|
||||
@@ -195,13 +195,13 @@ def sin_poly(d:UOp) -> UOp:
|
||||
def sin_poly_small(d:UOp, q:UOp) -> UOp:
|
||||
def _ifand(n:int): return (q & n).ne(0)
|
||||
r = sin_poly(d)
|
||||
return r * _ifand(1).where(r.const(-1), r.const(1))
|
||||
return r * _ifand(1).where(r.const_like(-1), r.const_like(1))
|
||||
|
||||
def sin_poly_large(d:UOp, q:UOp) -> UOp:
|
||||
def _ifand(n:int): return (q & n).ne(0)
|
||||
d = d + _ifand(1).where(d.const(math.pi / 2), d.const(0))
|
||||
d = d + _ifand(1).where(d.const_like(math.pi / 2), d.const_like(0))
|
||||
r = sin_poly(d)
|
||||
return r * _ifand(2).where(r.const(-1), r.const(1))
|
||||
return r * _ifand(2).where(r.const_like(-1), r.const_like(1))
|
||||
|
||||
# *** toplevel functions for xsin/xlog2/xexp2 ***
|
||||
|
||||
@@ -214,9 +214,9 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
|
||||
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
reduction_algo = cody_waite_reduction if fast else payne_hanek_reduction
|
||||
# mask +-inf/nan as zero
|
||||
x = _lazy_map_numbers(d, d.const(0.0), d.const(0.0), d.const(0.0), d)
|
||||
x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
|
||||
# x_sign = sign(x)
|
||||
x_sign = x.ne(0).where(x.lt(0).where(x.const(-1), x.const(1)), x.const(0))
|
||||
x_sign = x.ne(0).where(x.lt(0).where(x.const_like(-1), x.const_like(1)), x.const_like(0))
|
||||
x_abs = x * x_sign
|
||||
r, q = reduction_algo(x_abs)
|
||||
if fast: result = sin_poly_small(r, q)
|
||||
@@ -229,7 +229,7 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
|
||||
result = switch_over_map.where(sin_poly_small(r, q), sin_poly_large(r, q))
|
||||
result = result * x_sign # adjusts the sign for abs(x).
|
||||
# sin(Inf) = NaN, sin(-Inf) = NaN, sin(NaN) = NaN
|
||||
return _lazy_map_numbers(d, d.const(math.nan), d.const(math.nan), d.const(math.nan), result)
|
||||
return _lazy_map_numbers(d, d.const_like(math.nan), d.const_like(math.nan), d.const_like(math.nan), result)
|
||||
|
||||
def xexp2(d:UOp) -> UOp:
|
||||
"""
|
||||
@@ -239,30 +239,31 @@ def xexp2(d:UOp) -> UOp:
|
||||
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
fp64_p = d.dtype == dtypes.float64
|
||||
# mask +=inf/nan as zero.
|
||||
x = _lazy_map_numbers(d, d.const(0.0), d.const(0.0), d.const(0.0), d)
|
||||
x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
|
||||
q = rintk(x)
|
||||
# s = d - round(d)
|
||||
s = x - q.cast(x.dtype)
|
||||
# a polynomial approximation with 13 non-zero terms in the range of [−(log 2)/2,(log 2)/2].
|
||||
if fp64_p:
|
||||
u = polyN(s.const(0.4434359082926529454e-9), s, [0.7073164598085707425e-8, 0.1017819260921760451e-6, 0.1321543872511327615e-5,
|
||||
0.1525273353517584730e-4, 0.1540353045101147808e-3, 0.1333355814670499073e-2,
|
||||
0.9618129107597600536e-2, 0.5550410866482046596e-1, 0.2402265069591012214e+0,
|
||||
0.6931471805599452862e+0, 0.1000000000000000000e+1])
|
||||
u = polyN(s.const_like(0.4434359082926529454e-9), s, [0.7073164598085707425e-8, 0.1017819260921760451e-6, 0.1321543872511327615e-5,
|
||||
0.1525273353517584730e-4, 0.1540353045101147808e-3, 0.1333355814670499073e-2,
|
||||
0.9618129107597600536e-2, 0.5550410866482046596e-1, 0.2402265069591012214e+0,
|
||||
0.6931471805599452862e+0, 0.1000000000000000000e+1])
|
||||
else:
|
||||
u = polyN(s.const(0.1535920892e-3), s, [0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 0.1000000000e+1])
|
||||
u = polyN(s.const_like(0.1535920892e-3), s,
|
||||
[0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 0.1000000000e+1])
|
||||
u = ldexp2k(u, q) # u*2^q
|
||||
upper = {dtypes.float64: 1024, dtypes.float32: 128, dtypes.float16: 23.0}[x.dtype]
|
||||
lower = {dtypes.float64: -2000, dtypes.float32: -150, dtypes.float16: -22}[x.dtype]
|
||||
# Replace x >= upper with +inf
|
||||
u = x.ne(upper).where(u, x.const(math.inf))
|
||||
u = x.lt(upper).where(u, x.const(math.inf))
|
||||
u = x.ne(upper).where(u, x.const_like(math.inf))
|
||||
u = x.lt(upper).where(u, x.const_like(math.inf))
|
||||
# Replace x <= lower with zero.
|
||||
u = x.lt(lower).where(x.const(0.0), u)
|
||||
u = x.lt(lower).where(x.const_like(0.0), u)
|
||||
# x=NaN never satisfies x < Inf. (for fastmode)
|
||||
u = x.lt(math.inf).where(u, u.const(math.nan))
|
||||
u = x.lt(math.inf).where(u, u.const_like(math.nan))
|
||||
# exp2(Inf) = Inf, exp2(-Inf) = 0, exp2(NaN) = NaN
|
||||
return _lazy_map_numbers(d, d.const(math.inf), d.const(0.0), d.const(math.nan), u)
|
||||
return _lazy_map_numbers(d, d.const_like(math.inf), d.const_like(0.0), d.const_like(math.nan), u)
|
||||
|
||||
def xlog2(d:UOp) -> UOp:
|
||||
"""
|
||||
@@ -271,7 +272,7 @@ def xlog2(d:UOp) -> UOp:
|
||||
"""
|
||||
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
fp64_p = d.dtype == dtypes.float64
|
||||
FLT_MIN = d.const(1e-6 if d.dtype == dtypes.float16 else 1e-4)
|
||||
FLT_MIN = d.const_like(1e-6 if d.dtype == dtypes.float16 else 1e-4)
|
||||
d_orig = d
|
||||
denormal_map = d.lt(FLT_MIN)
|
||||
for _ in range(8): d = denormal_map.where(d * (2 ** 8), d)
|
||||
@@ -283,26 +284,28 @@ def xlog2(d:UOp) -> UOp:
|
||||
if fp64_p:
|
||||
x = (m - 1.0) * (m + 1.0).recip()
|
||||
x2 = x * x
|
||||
t = polyN(x.const(0.2211941750456081490e+0), x2, [0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0,
|
||||
0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449])
|
||||
s_hi, s_lo = dfadd2_f2_f2_f2(e, e.const(0), *dfmul2_f2_f2_f2(t.const(2.885390081777926774), t.const(0), x, x.const(0)))
|
||||
t = polyN(x.const_like(0.2211941750456081490e+0), x2, [0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0,
|
||||
0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449])
|
||||
s_hi, s_lo = dfadd2_f2_f2_f2(e, e.const_like(0), *dfmul2_f2_f2_f2(t.const_like(2.885390081777926774), t.const_like(0), x, x.const_like(0)))
|
||||
r = t * (x * x2) + (s_hi + s_lo)
|
||||
else:
|
||||
xx, xy = dfdiv2_f2_f2_f2(*dfadd2_f2_f2_f2(m.const(-1), m.const(0), m, m.const(0)), *dfadd2_f2_f2_f2(m.const(1), m.const(0), m, m.const(0)))
|
||||
xx, xy = dfdiv2_f2_f2_f2(*dfadd2_f2_f2_f2(m.const_like(-1), m.const_like(0), m, m.const_like(0)),
|
||||
*dfadd2_f2_f2_f2(m.const_like(1), m.const_like(0), m, m.const_like(0)))
|
||||
x2 = xx * xx
|
||||
t = polyN(d.const(0.4374550283e+0), x2, [0.5764790177e+0, 0.9618012905120])
|
||||
sx, sy = dfadd2_f2_f2_f2(e, e.const(0), *dfmul2_f2_f2_f2(xx, xy, xx.const(2.8853900432586669922), xy.const(3.2734474483568488616e-08)))
|
||||
sx, sy = dfadd2_f2_f2_f2(sx, sy, x2.const(0), (x2 * xx) * t)
|
||||
t = polyN(d.const_like(0.4374550283e+0), x2, [0.5764790177e+0, 0.9618012905120])
|
||||
sx, sy = dfadd2_f2_f2_f2(e, e.const_like(0),
|
||||
*dfmul2_f2_f2_f2(xx, xy, xx.const_like(2.8853900432586669922), xy.const_like(3.2734474483568488616e-08)))
|
||||
sx, sy = dfadd2_f2_f2_f2(sx, sy, x2.const_like(0), (x2 * xx) * t)
|
||||
r = sx + sy
|
||||
# log2(Inf) = Inf
|
||||
r = d_orig.ne(math.inf).where(r, r.const(math.inf))
|
||||
r = d_orig.ne(math.inf).where(r, r.const_like(math.inf))
|
||||
# log2(x=-0.01) = NaN. where x < 0
|
||||
r = d_orig.lt(-0.0).where(r.const(math.nan), r)
|
||||
r = d_orig.lt(-0.0).where(r.const_like(math.nan), r)
|
||||
# log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
|
||||
# log2_zero = the value of unmasked xlog2(0.0).
|
||||
log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79, None: -math.inf}[d.dtype]
|
||||
r = r.ne(log2_zero).where(r, r.const(-math.inf))
|
||||
r = r.ne(log2_zero).where(r, r.const_like(-math.inf))
|
||||
# log(NaN) = NaN, using for all real number x, either of x < Inf, x == Inf becomes True.
|
||||
r = d_orig.lt(math.inf).where(r, d_orig.ne(math.inf).where(d.const(math.nan), d))
|
||||
r = d_orig.lt(math.inf).where(r, d_orig.ne(math.inf).where(d.const_like(math.nan), d))
|
||||
# log(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal.
|
||||
return d_orig.recip().ne(-math.inf).where(r, r.const(-math.inf))
|
||||
return d_orig.recip().ne(-math.inf).where(r, r.const_like(-math.inf))
|
||||
|
||||
@@ -78,7 +78,7 @@ def fix_unfoldable_image_load(load:UOp, buf:UOp):
|
||||
if len(new_src) >= 4:
|
||||
new_src[2] = UOp(UOps.VECTORIZE, cast(DType, new_src[2].dtype).vec(4), tuple(new_src[2] for _ in range(4)))
|
||||
vec_load = UOp(UOps.LOAD, cast(DType, load.dtype).vec(4), tuple(new_src))
|
||||
return functools.reduce(lambda ret, i: id4.ne(i).where(ret, UOp(UOps.GEP, load.dtype, (vec_load,), i)), range(4), load.const(float('nan')))
|
||||
return functools.reduce(lambda ret, i: id4.ne(i).where(ret, UOp(UOps.GEP, load.dtype, (vec_load,), i)), range(4), load.const_like(float('nan')))
|
||||
|
||||
float4_folding = PatternMatcher([
|
||||
(UPat(UOps.EXPAND, src=UPat(UOps.LOAD, src=(UPat(name="buf"), UPat()), allow_any_len=True), name="ex"), fold_expanded),
|
||||
@@ -110,13 +110,13 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
something_changed = True
|
||||
else: remainder.append(u)
|
||||
if not something_changed: return None
|
||||
return functools.reduce(operator.add, remainder)%c if remainder else x.const(0)
|
||||
return functools.reduce(operator.add, remainder)%c if remainder else x.const_like(0)
|
||||
|
||||
def div_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
# simplify x // c, None means no change
|
||||
|
||||
# simple cancel div case
|
||||
if 0 <= x.vmin.arg and x.vmax.arg < c: return x.const(0)
|
||||
if 0 <= x.vmin.arg and x.vmax.arg < c: return x.const_like(0)
|
||||
|
||||
quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1
|
||||
for u in _get_add_chain(x):
|
||||
@@ -136,9 +136,9 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
# handle the const
|
||||
if rem_const%c != rem_const:
|
||||
something_changed = True
|
||||
quotient.append(x.const(rem_const//c))
|
||||
quotient.append(x.const_like(rem_const//c))
|
||||
rem_const = rem_const%c
|
||||
if rem_const != 0: remainder.append(x.const(rem_const))
|
||||
if rem_const != 0: remainder.append(x.const_like(rem_const))
|
||||
|
||||
# x // c -> quotient + (remainder // div) // (c // div)
|
||||
div = gcd if gcd > 1 else divisor
|
||||
@@ -146,7 +146,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
if not something_changed: return newx//(c//div) if 1 < div < c and (newx:=div_folding(x, div)) is not None else None
|
||||
rem:Optional[UOp] = functools.reduce(operator.add, remainder) if remainder else None
|
||||
quo:Optional[UOp] = functools.reduce(operator.add, quotient) if quotient else None
|
||||
if quo is None: return x.const(0) if rem is None else cast(UOp, div_folding(rem, div))//(c//div)
|
||||
if quo is None: return x.const_like(0) if rem is None else cast(UOp, div_folding(rem, div))//(c//div)
|
||||
return quo if rem is None else cast(UOp, div_folding(rem, div))//(c//div)+quo
|
||||
|
||||
# ***** transcendental *****
|
||||
@@ -197,7 +197,7 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, redu
|
||||
|
||||
def index_collapse(idx,rng,buf,add,mul,ld,reduce):
|
||||
if rng not in reduce.src: return None
|
||||
return UOp(reduce.op, reduce.dtype, (UOp(ld.op, ld.dtype, (buf, add+mul*idx, ld.const(0), idx.ge(rng.src[0]) & idx.lt(rng.src[1]))),)+
|
||||
return UOp(reduce.op, reduce.dtype, (UOp(ld.op, ld.dtype, (buf, add+mul*idx, ld.const_like(0), idx.ge(rng.src[0]) & idx.lt(rng.src[1]))),)+
|
||||
tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg)
|
||||
|
||||
# this is symbolic 2.0
|
||||
@@ -251,23 +251,23 @@ constant_folder = PatternMatcher([
|
||||
# max folding
|
||||
(NOp.max(NOp.var('x'), NOp.var('y')), lambda x,y: x if x.vmin.arg >= y.vmax.arg else y if x.vmax.arg <= y.vmin.arg else None),
|
||||
# GEP/CAST const rules
|
||||
(NOp(UOps.GEP, src=(NOp.cvar("c"),), name="root"), lambda root, c: root.const(c.arg)),
|
||||
(UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: root.const(c.arg)),
|
||||
(NOp(UOps.GEP, src=(NOp.cvar("c"),), name="root"), lambda root, c: root.const_like(c.arg)),
|
||||
(UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: root.const_like(c.arg)),
|
||||
# a conditional with the same results either way is a noop, also fold const conditionals
|
||||
(NOp.var().where(NOp.var("val"), NOp.var("val")), lambda val: val),
|
||||
(NOp.cvar('gate').where(NOp.var('c0'), NOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),
|
||||
# ** constant folding **
|
||||
(UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: root.const(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
|
||||
(UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
|
||||
# ** self folding **
|
||||
# cast NOOP (NOTE: it's str to deal with PtrDType)
|
||||
(NOp(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
|
||||
(NOp(UOps.REDUCE, src=(NOp.var('x'),)), lambda x: x), # a REDUCE without ranges is a NOOP
|
||||
(NOp.var('x') + 0, lambda x: x), # x+0 -> x
|
||||
(NOp.var('x') * 1, lambda x: x), # x*1 -> x
|
||||
(NOp.var('x') // NOp.var('x'), lambda x: x.const(1)), # x//x -> 1
|
||||
(NOp.var('x') // NOp.var('x'), lambda x: x.const_like(1)), # x//x -> 1
|
||||
(NOp.var('x') // 1, lambda x: x), # x//1 -> x
|
||||
(NOp.var('x') // -1, lambda x: -x), # x//-1 -> -x
|
||||
(NOp.var('x') / NOp.var('x'), lambda x: x.const(1)), # x/x -> 1
|
||||
(NOp.var('x') / NOp.var('x'), lambda x: x.const_like(1)), # x/x -> 1
|
||||
((NOp.var("x") * NOp.var("x2")) / NOp.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
|
||||
(NOp.var('x', dtype=dtypes.bool) & NOp.cvar('c'), lambda x,c: x if c.arg else c),
|
||||
(NOp.var('x', dtype=dtypes.bool) | NOp.cvar('c'), lambda x,c: c if c.arg else x),
|
||||
@@ -275,9 +275,9 @@ constant_folder = PatternMatcher([
|
||||
# x*0 -> 0 or 0*x -> 0
|
||||
# if x is nan or inf it should render the nan value.
|
||||
# NOTE: this can be wrong for loaded NaN
|
||||
(NOp.var('x') * 0, lambda x: x.const(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
|
||||
(NOp.var('x') * 0, lambda x: x.const_like(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
|
||||
# min==max -> CONST (slow!)
|
||||
(UPat({UOps.ALU, UOps.DEFINE_VAR}, name='x'), lambda x: x.const(x.vmin.arg) if x.vmin.arg == x.vmax.arg else None),
|
||||
(UPat({UOps.ALU, UOps.DEFINE_VAR}, name='x'), lambda x: x.const_like(x.vmin.arg) if x.vmin.arg == x.vmax.arg else None),
|
||||
# ** load/store folding **
|
||||
(NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.load(NOp.var("buf"), NOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
|
||||
# ** two stage add/mul folding **
|
||||
@@ -381,7 +381,8 @@ def do_reduce(root:UOp):
|
||||
ret = root.src[0]
|
||||
if len(reduce_parented):
|
||||
assert root.dtype is not None
|
||||
acc = UOp(UOps.DEFINE_ACC, root.dtype, (root.const(identity_element(root.arg, root.dtype.scalar())),) + tuple(reduce_parented), (acc_number,))
|
||||
acc = UOp(UOps.DEFINE_ACC, root.dtype,
|
||||
(root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(reduce_parented), (acc_number,))
|
||||
acc_number += 1
|
||||
ret = UOp(UOps.PHI, root.dtype, (acc, acc.alu(root.arg, ret)))
|
||||
# for MAX, we can just ignore the unparented
|
||||
|
||||
@@ -263,6 +263,7 @@ class UOp:
|
||||
dtype: Optional[DType] = None
|
||||
src: Tuple[UOp, ...] = tuple()
|
||||
arg: Any = None
|
||||
def __hash__(self): return id(self)
|
||||
def commutative(self) -> bool:
|
||||
return (self.op is UOps.ALU and \
|
||||
self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR})
|
||||
@@ -286,7 +287,7 @@ class UOp:
|
||||
assert ret.op is UOps.SHAPETRACKER, f"st_arg trying to return {ret}"
|
||||
return ret.arg
|
||||
def sink(self, *srcs): return UOp(UOps.SINK, None, (self,)+srcs)
|
||||
def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x
|
||||
def ufix(self, x): return self.const_like(x) if not isinstance(x, UOp) else x
|
||||
def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
|
||||
def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,))
|
||||
def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar() if self.dtype is not None else None, (self,), i)
|
||||
@@ -310,17 +311,18 @@ class UOp:
|
||||
def min(self, x): return -(-self).max(-x)
|
||||
def where(self, x, y): return self.alu(TernaryOps.WHERE, x, y)
|
||||
def recip(self): return self.alu(UnaryOps.RECIP)
|
||||
def const(self:Union[UOp, DType, None], b:ConstType|Variable): return UOp._const(self.dtype if isinstance(self, UOp) else self, b)
|
||||
def sconst(self:Union[UOp, DType, None], b:ConstType|Variable):
|
||||
return UOp._const(cast(DType, self.dtype if isinstance(self, UOp) else self).scalar() if self is not None else self, b)
|
||||
def const_like(self, b:ConstType|Variable): return type(self).const(self.dtype, b)
|
||||
def sconst_like(self, b:ConstType|Variable): return type(self).const(self.dtype.scalar() if self.dtype is not None else None, b)
|
||||
@classmethod
|
||||
@functools.lru_cache(None)
|
||||
def const(cls, dtype:Optional[DType], b:ConstType|Variable): return UOp._const(cls, dtype, b)
|
||||
@staticmethod
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _const(dtype:Optional[DType], b:ConstType|Variable):
|
||||
def _const(typ, dtype:Optional[DType], b:ConstType|Variable):
|
||||
# TODO: fix dtype of b.max after Variable is just an UOp
|
||||
if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (UOp.const(dtypes.int, b.min), UOp.const(dtypes.int, cast(int,b.max))), b)
|
||||
if isinstance(b, Variable): return typ(UOps.DEFINE_VAR, dtype, (typ.const(dtypes.int, b.min), typ.const(dtypes.int, cast(int,b.max))), b)
|
||||
if dtype is not None and dtype != (sdtype := dtype.scalar()):
|
||||
return UOp(UOps.VECTORIZE, dtype, src=tuple(UOp(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count)))
|
||||
return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
|
||||
return typ(UOps.VECTORIZE, dtype, src=tuple(typ(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count)))
|
||||
return typ(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
|
||||
def alu(self, arg, *src:UOp):
|
||||
return type(self)(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else (self, *src)[-1].dtype, (self,)+src, arg)
|
||||
@staticmethod
|
||||
@@ -349,7 +351,7 @@ class UOp:
|
||||
return 1
|
||||
def divides(self, v) -> Optional[UOp]:
|
||||
if v==1: return self
|
||||
if self.op is UOps.CONST: return self.const(self.arg//v) if self.arg%v == 0 else None
|
||||
if self.op is UOps.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
|
||||
if self.op is UOps.ALU:
|
||||
if self.arg is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
|
||||
if self.arg is BinaryOps.MUL:
|
||||
@@ -357,32 +359,34 @@ class UOp:
|
||||
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
|
||||
return None # generic None if we aren't sure
|
||||
@property
|
||||
def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.min(cast(DType, self.dtype)))
|
||||
def vmin(self) -> UOp:
|
||||
return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.sconst_like(dtypes.min(cast(DType, self.dtype)))
|
||||
@property
|
||||
def vmax(self) -> UOp: return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.max(cast(DType, self.dtype)))
|
||||
def vmax(self) -> UOp:
|
||||
return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.sconst_like(dtypes.max(cast(DType, self.dtype)))
|
||||
@functools.cached_property
|
||||
def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]:
|
||||
# NOTE: returned UOp is assumed to be CONST
|
||||
if self.op is UOps.DEFINE_VAR and self.src: return self.src[0], self.src[1] if isinstance(self.src[1].arg, int) else None
|
||||
if self.op is UOps.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
|
||||
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
|
||||
if self.op is UOps.SPECIAL: return self.const(0), self.const(self.arg[1]-1) if isinstance(self.arg[1], int) else None
|
||||
if self.op is UOps.SPECIAL: return self.const_like(0), self.const_like(self.arg[1]-1) if isinstance(self.arg[1], int) else None
|
||||
if self.op is UOps.CONST: return self, self
|
||||
if self.op is UOps.ALU and cast(DType, self.dtype).count == 1:
|
||||
s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
|
||||
if self.arg is BinaryOps.ADD: return self.sconst(s0.vmin.arg+s1.vmin.arg), self.sconst(s0.vmax.arg+s1.vmax.arg)
|
||||
if self.arg is BinaryOps.ADD: return self.sconst_like(s0.vmin.arg+s1.vmin.arg), self.sconst_like(s0.vmax.arg+s1.vmax.arg)
|
||||
if self.arg is BinaryOps.MUL and (s0.vmin.arg >= 0 or s1.vmin.arg >= 0):
|
||||
# handle at lease one is non-negative
|
||||
Lmin, Lmax = (s0.vmin.arg, s0.vmax.arg) if s1.vmin.arg >= 0 else (s0.vmax.arg, s0.vmin.arg)
|
||||
Rmin, Rmax = (s1.vmin.arg, s1.vmax.arg) if s0.vmin.arg >= 0 else (s1.vmax.arg, s1.vmin.arg)
|
||||
assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}"
|
||||
return self.sconst(Lmin*Rmin), self.sconst(Lmax*Rmax)
|
||||
if self.arg is BinaryOps.MOD and s1.vmin.arg > 0: return self.sconst(0), self.sconst(s1.vmax.arg-1)
|
||||
return self.sconst_like(Lmin*Rmin), self.sconst_like(Lmax*Rmax)
|
||||
if self.arg is BinaryOps.MOD and s1.vmin.arg > 0: return self.sconst_like(0), self.sconst_like(s1.vmax.arg-1)
|
||||
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
|
||||
if s1.arg > 0: return self.sconst(s0.vmin.arg//s1.arg), self.sconst(s0.vmax.arg//s1.arg)
|
||||
if s1.arg < 0: return self.sconst(-(s0.vmax.arg//-s1.arg)), self.sconst(-(s0.vmin.arg//-s1.arg))
|
||||
if self.arg is BinaryOps.MAX: return self.sconst(max(s0.vmin.arg, s1.vmin.arg)), self.sconst(max(s0.vmax.arg, s1.vmax.arg))
|
||||
if self.arg is BinaryOps.CMPLT: return (UOp.sconst(dtypes.bool, s0.vmax.arg<s1.vmin.arg), UOp.sconst(dtypes.bool, s0.vmin.arg<s1.vmax.arg))
|
||||
if s1.arg > 0: return self.sconst_like(s0.vmin.arg//s1.arg), self.sconst_like(s0.vmax.arg//s1.arg)
|
||||
if s1.arg < 0: return self.sconst_like(-(s0.vmax.arg//-s1.arg)), self.sconst_like(-(s0.vmin.arg//-s1.arg))
|
||||
if self.arg is BinaryOps.MAX: return self.sconst_like(max(s0.vmin.arg, s1.vmin.arg)), self.sconst_like(max(s0.vmax.arg, s1.vmax.arg))
|
||||
if self.arg is BinaryOps.CMPLT: return (UOp.const(dtypes.bool, s0.vmax.arg<s1.vmin.arg), UOp.const(dtypes.bool, s0.vmin.arg<s1.vmax.arg))
|
||||
return None, None
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -399,7 +403,7 @@ def get_location() -> Tuple[str, int]:
|
||||
while (frm.f_code.co_filename.endswith("/ops.py") or frm.f_code.co_filename == '<string>') and frm.f_back is not None: frm = frm.f_back
|
||||
return frm.f_code.co_filename, frm.f_lineno
|
||||
@functools.lru_cache(None)
|
||||
def lines(fn): return open(fn).readlines()
|
||||
def lines(fn) -> List[str]: return open(fn).readlines()
|
||||
|
||||
@dataclass(frozen=True, repr=False) # reuse repr from UOp
|
||||
class NOp(UOp):
|
||||
@@ -414,7 +418,11 @@ class NOp(UOp):
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.CONST, dtype=dtype, name=name)
|
||||
def const(self:Union[UOp, DType, None], b:ConstType|Variable): return NOp((x:=UOp.const(self, b)).op, x.dtype, x.src, x.arg)
|
||||
|
||||
# this is needed so NOp has a different cache
|
||||
@classmethod
|
||||
@functools.lru_cache(None)
|
||||
def const(cls, dtype:Optional[DType], b:ConstType|Variable): return cls._const(cls, dtype, b)
|
||||
|
||||
@functools.cached_property
|
||||
def upat(self:NOp) -> UPat:
|
||||
@@ -447,7 +455,11 @@ class UPat:
|
||||
upat_match = [self.in_src] if isinstance(self.in_src, UPat) else ([] if self.in_src is None else self.src[0])
|
||||
self.early_reject = set((pp.op[0], pp.arg) for pp in upat_match if pp.op is not None and len(pp.op) == 1)
|
||||
|
||||
def printable(self:UPat): return lines(self.location[0])[self.location[1]-1].strip()
|
||||
def printable(self:UPat) -> str:
|
||||
try:
|
||||
return lines(self.location[0])[self.location[1]-1].strip()
|
||||
except FileNotFoundError:
|
||||
return "<missing>"
|
||||
def __repr__(self):
|
||||
def rep(x):
|
||||
form = "UPat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)"
|
||||
|
||||
@@ -66,7 +66,7 @@ ptx_matcher = PatternMatcher([
|
||||
UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
|
||||
lambda root, alu, const: UOp(root.op, root.dtype,
|
||||
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
|
||||
const.const(root.src[0].dtype.itemsize)*const)+root.src[2:])),
|
||||
const*root.src[0].dtype.itemsize)+root.src[2:])),
|
||||
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
|
||||
UPat(UOps.CONST, name="const"))),
|
||||
lambda root, const: UOp(root.op, root.dtype,
|
||||
|
||||
Reference in New Issue
Block a user