diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 2daa69da68..70a7c49dbd 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -13,7 +13,7 @@ simple_pm = PatternMatcher([ (NOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)), (NOp.cvar('x') + NOp.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)), (NOp.cvar('x') * NOp.cvar('y') * NOp.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)), - ((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x + x.const(c1.arg+c2.arg)), + ((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)), ]) def to_uops_list(u:List[UOp]) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u))) @@ -134,7 +134,7 @@ class TestGraphRewrite(unittest.TestCase): b = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('b', 0, 1)) c = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('c', 0, 1)) d = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('d', 0, 1)) - outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b] + outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b] for out in outs: sink = graph_rewrite(out, constant_folder) print(sink) diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index f1a1e3f645..e42f8d3e36 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -93,16 +93,16 @@ class IndependentLowerer: idx, valid = x.st_arg.to_indexed_uops(self.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else self.idxs) # TODO: check has_valid in UPat, not here has_valid = valid.op is not UOps.CONST or valid.arg is not True - if x.op is UOps.CONST: return valid.where(x.const(x.arg), x.const(0)) + if x.op is UOps.CONST: return valid.where(x.const_like(x.arg), x.const_like(0)) buf = x.src[0] if x.op is UOps.LOAD: barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[2]),)),) if x.src[0].op is UOps.DEFINE_LOCAL else () - return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((x.const(0), valid) if has_valid else ()) + barrier) + return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((x.const_like(0), valid) if has_valid else ()) + barrier) # NOTE: only store the local reduceop in the threads that are actually doing the reduce store_back = x.src[0].op is UOps.DEFINE_LOCAL and x.src[2].op is UOps.REDUCE_AXIS and \ x.src[2].src[0].op is UOps.LOAD and x.src[2].src[0].src[0].op is UOps.DEFINE_LOCAL # NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes - if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const(0) if i in x.src[2].arg[1] else u for i,u in enumerate(self.idxs)]) + if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if i in x.src[2].arg[1] else u for i,u in enumerate(self.idxs)]) if x.src[0].op is UOps.DEFINE_GLOBAL or store_back: for oidx, ridx in zip(self.idxs, self.ridxs): if oidx != ridx: valid = valid * oidx.eq(0) diff --git a/tinygrad/codegen/transcendental.py b/tinygrad/codegen/transcendental.py index dc433f0358..a2fb6a3c5b 100644 --- a/tinygrad/codegen/transcendental.py +++ b/tinygrad/codegen/transcendental.py @@ -38,7 +38,7 @@ def rintk(d:UOp) -> UOp: """ceiling(d:float) -> int""" assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES return_t = 
{dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype] - return (d + d.lt(0.0).where(d.const(-0.5), d.const(0.5))).cast(return_t) + return (d + d.lt(0.0).where(d.const_like(-0.5), d.const_like(0.5))).cast(return_t) def pow2if(q:UOp, float_dtype:DType): """cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]""" @@ -82,7 +82,7 @@ def frexp(v:UOp) -> Tuple[UOp, UOp]: result_f = bits_to_float((bits & m1) | m2, v.dtype) value = exponent_zero.where(result_f, v) exp = exponent + (-bias) - exp = exponent_zero.where(exp, exp.const(0)) + exp = exponent_zero.where(exp, exp.const_like(0)) if v.dtype == dtypes.float16: exp = exp.bitcast(dtypes.int16) return value, exp @@ -114,7 +114,7 @@ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]: def _take(an:UOp, offset:int, count:int=0) -> UOp: """an = two_over_pi_f[i+offset]""" if count+offset <= len(two_over_pi_f[0:-2]): - an = _eq(i, count).where(_take(an, offset, count=count+1), an.const(two_over_pi_f[count+offset])) + an = _eq(i, count).where(_take(an, offset, count=count+1), an.const_like(two_over_pi_f[count+offset])) return an def _exact_pow2if(x): return pow2if(x, input_dtype).cast(acc_dtype) def _shl_lazy(x, y): return (x.cast(acc_dtype) * _exact_pow2if(y)).cast(dtypes.uint32) @@ -183,7 +183,7 @@ def trig_poly(d:UOp, coeff32, coeff64): u = __poly8(s, s2, s4, *coeff64[:-1]) u = u * s + coeff64[-1] else: - u = polyN(s.const(coeff32[0]), s, coeff32[1:]) + u = polyN(s.const_like(coeff32[0]), s, coeff32[1:]) return s * (u * d) + d # approximate sine on [-pi/2, pi/2] def sin_poly(d:UOp) -> UOp: @@ -195,13 +195,13 @@ def sin_poly(d:UOp) -> UOp: def sin_poly_small(d:UOp, q:UOp) -> UOp: def _ifand(n:int): return (q & n).ne(0) r = sin_poly(d) - return r * _ifand(1).where(r.const(-1), r.const(1)) + return r * _ifand(1).where(r.const_like(-1), r.const_like(1)) def sin_poly_large(d:UOp, q:UOp) -> UOp: def _ifand(n:int): return (q & n).ne(0) - d = d + _ifand(1).where(d.const(math.pi / 2), d.const(0)) + d = d + _ifand(1).where(d.const_like(math.pi / 2), d.const_like(0)) r = sin_poly(d) - return r * _ifand(2).where(r.const(-1), r.const(1)) + return r * _ifand(2).where(r.const_like(-1), r.const_like(1)) # *** toplevel functions for xsin/xlog2/xexp2 *** @@ -214,9 +214,9 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp: assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES reduction_algo = cody_waite_reduction if fast else payne_hanek_reduction # mask +-inf/nan as zero - x = _lazy_map_numbers(d, d.const(0.0), d.const(0.0), d.const(0.0), d) + x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d) # x_sign = sign(x) - x_sign = x.ne(0).where(x.lt(0).where(x.const(-1), x.const(1)), x.const(0)) + x_sign = x.ne(0).where(x.lt(0).where(x.const_like(-1), x.const_like(1)), x.const_like(0)) x_abs = x * x_sign r, q = reduction_algo(x_abs) if fast: result = sin_poly_small(r, q) @@ -229,7 +229,7 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp: result = switch_over_map.where(sin_poly_small(r, q), sin_poly_large(r, q)) result = result * x_sign # adjusts the sign for abs(x). 
# sin(Inf) = NaN, sin(-Inf) = NaN, sin(NaN) = NaN - return _lazy_map_numbers(d, d.const(math.nan), d.const(math.nan), d.const(math.nan), result) + return _lazy_map_numbers(d, d.const_like(math.nan), d.const_like(math.nan), d.const_like(math.nan), result) def xexp2(d:UOp) -> UOp: """ @@ -239,30 +239,31 @@ def xexp2(d:UOp) -> UOp: assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES fp64_p = d.dtype == dtypes.float64 # mask +=inf/nan as zero. - x = _lazy_map_numbers(d, d.const(0.0), d.const(0.0), d.const(0.0), d) + x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d) q = rintk(x) # s = d - round(d) s = x - q.cast(x.dtype) # a polynomial approximation with 13 non-zero terms in the range of [−(log 2)/2,(log 2)/2]. if fp64_p: - u = polyN(s.const(0.4434359082926529454e-9), s, [0.7073164598085707425e-8, 0.1017819260921760451e-6, 0.1321543872511327615e-5, - 0.1525273353517584730e-4, 0.1540353045101147808e-3, 0.1333355814670499073e-2, - 0.9618129107597600536e-2, 0.5550410866482046596e-1, 0.2402265069591012214e+0, - 0.6931471805599452862e+0, 0.1000000000000000000e+1]) + u = polyN(s.const_like(0.4434359082926529454e-9), s, [0.7073164598085707425e-8, 0.1017819260921760451e-6, 0.1321543872511327615e-5, + 0.1525273353517584730e-4, 0.1540353045101147808e-3, 0.1333355814670499073e-2, + 0.9618129107597600536e-2, 0.5550410866482046596e-1, 0.2402265069591012214e+0, + 0.6931471805599452862e+0, 0.1000000000000000000e+1]) else: - u = polyN(s.const(0.1535920892e-3), s, [0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 0.1000000000e+1]) + u = polyN(s.const_like(0.1535920892e-3), s, + [0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 0.1000000000e+1]) u = ldexp2k(u, q) # u*2^q upper = {dtypes.float64: 1024, dtypes.float32: 128, dtypes.float16: 23.0}[x.dtype] lower = {dtypes.float64: -2000, dtypes.float32: -150, dtypes.float16: -22}[x.dtype] # Replace x >= upper with +inf - u = x.ne(upper).where(u, x.const(math.inf)) - u = x.lt(upper).where(u, x.const(math.inf)) + u = x.ne(upper).where(u, x.const_like(math.inf)) + u = x.lt(upper).where(u, x.const_like(math.inf)) # Replace x <= lower with zero. - u = x.lt(lower).where(x.const(0.0), u) + u = x.lt(lower).where(x.const_like(0.0), u) # x=NaN never satisfies x < Inf. 
(for fastmode) - u = x.lt(math.inf).where(u, u.const(math.nan)) + u = x.lt(math.inf).where(u, u.const_like(math.nan)) # exp2(Inf) = Inf, exp2(-Inf) = 0, exp2(NaN) = NaN - return _lazy_map_numbers(d, d.const(math.inf), d.const(0.0), d.const(math.nan), u) + return _lazy_map_numbers(d, d.const_like(math.inf), d.const_like(0.0), d.const_like(math.nan), u) def xlog2(d:UOp) -> UOp: """ @@ -271,7 +272,7 @@ def xlog2(d:UOp) -> UOp: """ assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES fp64_p = d.dtype == dtypes.float64 - FLT_MIN = d.const(1e-6 if d.dtype == dtypes.float16 else 1e-4) + FLT_MIN = d.const_like(1e-6 if d.dtype == dtypes.float16 else 1e-4) d_orig = d denormal_map = d.lt(FLT_MIN) for _ in range(8): d = denormal_map.where(d * (2 ** 8), d) @@ -283,26 +284,28 @@ def xlog2(d:UOp) -> UOp: if fp64_p: x = (m - 1.0) * (m + 1.0).recip() x2 = x * x - t = polyN(x.const(0.2211941750456081490e+0), x2, [0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0, - 0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449]) - s_hi, s_lo = dfadd2_f2_f2_f2(e, e.const(0), *dfmul2_f2_f2_f2(t.const(2.885390081777926774), t.const(0), x, x.const(0))) + t = polyN(x.const_like(0.2211941750456081490e+0), x2, [0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0, + 0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449]) + s_hi, s_lo = dfadd2_f2_f2_f2(e, e.const_like(0), *dfmul2_f2_f2_f2(t.const_like(2.885390081777926774), t.const_like(0), x, x.const_like(0))) r = t * (x * x2) + (s_hi + s_lo) else: - xx, xy = dfdiv2_f2_f2_f2(*dfadd2_f2_f2_f2(m.const(-1), m.const(0), m, m.const(0)), *dfadd2_f2_f2_f2(m.const(1), m.const(0), m, m.const(0))) + xx, xy = dfdiv2_f2_f2_f2(*dfadd2_f2_f2_f2(m.const_like(-1), m.const_like(0), m, m.const_like(0)), + *dfadd2_f2_f2_f2(m.const_like(1), m.const_like(0), m, m.const_like(0))) x2 = xx * xx - t = polyN(d.const(0.4374550283e+0), x2, [0.5764790177e+0, 0.9618012905120]) - sx, sy = dfadd2_f2_f2_f2(e, e.const(0), *dfmul2_f2_f2_f2(xx, xy, xx.const(2.8853900432586669922), xy.const(3.2734474483568488616e-08))) - sx, sy = dfadd2_f2_f2_f2(sx, sy, x2.const(0), (x2 * xx) * t) + t = polyN(d.const_like(0.4374550283e+0), x2, [0.5764790177e+0, 0.9618012905120]) + sx, sy = dfadd2_f2_f2_f2(e, e.const_like(0), + *dfmul2_f2_f2_f2(xx, xy, xx.const_like(2.8853900432586669922), xy.const_like(3.2734474483568488616e-08))) + sx, sy = dfadd2_f2_f2_f2(sx, sy, x2.const_like(0), (x2 * xx) * t) r = sx + sy # log2(Inf) = Inf - r = d_orig.ne(math.inf).where(r, r.const(math.inf)) + r = d_orig.ne(math.inf).where(r, r.const_like(math.inf)) # log2(x=-0.01) = NaN. where x < 0 - r = d_orig.lt(-0.0).where(r.const(math.nan), r) + r = d_orig.lt(-0.0).where(r.const_like(math.nan), r) # log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true. # log2_zero = the value of unmasked xlog2(0.0). log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79, None: -math.inf}[d.dtype] - r = r.ne(log2_zero).where(r, r.const(-math.inf)) + r = r.ne(log2_zero).where(r, r.const_like(-math.inf)) # log(NaN) = NaN, using for all real number x, either of x < Inf, x == Inf becomes True. - r = d_orig.lt(math.inf).where(r, d_orig.ne(math.inf).where(d.const(math.nan), d)) + r = d_orig.lt(math.inf).where(r, d_orig.ne(math.inf).where(d.const_like(math.nan), d)) # log(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal. 
- return d_orig.recip().ne(-math.inf).where(r, r.const(-math.inf)) + return d_orig.recip().ne(-math.inf).where(r, r.const_like(-math.inf)) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 7dea1e3dc4..12035e13ec 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -78,7 +78,7 @@ def fix_unfoldable_image_load(load:UOp, buf:UOp): if len(new_src) >= 4: new_src[2] = UOp(UOps.VECTORIZE, cast(DType, new_src[2].dtype).vec(4), tuple(new_src[2] for _ in range(4))) vec_load = UOp(UOps.LOAD, cast(DType, load.dtype).vec(4), tuple(new_src)) - return functools.reduce(lambda ret, i: id4.ne(i).where(ret, UOp(UOps.GEP, load.dtype, (vec_load,), i)), range(4), load.const(float('nan'))) + return functools.reduce(lambda ret, i: id4.ne(i).where(ret, UOp(UOps.GEP, load.dtype, (vec_load,), i)), range(4), load.const_like(float('nan'))) float4_folding = PatternMatcher([ (UPat(UOps.EXPAND, src=UPat(UOps.LOAD, src=(UPat(name="buf"), UPat()), allow_any_len=True), name="ex"), fold_expanded), @@ -110,13 +110,13 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]: something_changed = True else: remainder.append(u) if not something_changed: return None - return functools.reduce(operator.add, remainder)%c if remainder else x.const(0) + return functools.reduce(operator.add, remainder)%c if remainder else x.const_like(0) def div_folding(x:UOp, c:int) -> Optional[UOp]: # simplify x // c, None means no change # simple cancel div case - if 0 <= x.vmin.arg and x.vmax.arg < c: return x.const(0) + if 0 <= x.vmin.arg and x.vmax.arg < c: return x.const_like(0) quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1 for u in _get_add_chain(x): @@ -136,9 +136,9 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]: # handle the const if rem_const%c != rem_const: something_changed = True - quotient.append(x.const(rem_const//c)) + quotient.append(x.const_like(rem_const//c)) rem_const = rem_const%c - if rem_const != 0: remainder.append(x.const(rem_const)) + if rem_const != 0: remainder.append(x.const_like(rem_const)) # x // c -> quotient + (remainder // div) // (c // div) div = gcd if gcd > 1 else divisor @@ -146,7 +146,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]: if not something_changed: return newx//(c//div) if 1 < div < c and (newx:=div_folding(x, div)) is not None else None rem:Optional[UOp] = functools.reduce(operator.add, remainder) if remainder else None quo:Optional[UOp] = functools.reduce(operator.add, quotient) if quotient else None - if quo is None: return x.const(0) if rem is None else cast(UOp, div_folding(rem, div))//(c//div) + if quo is None: return x.const_like(0) if rem is None else cast(UOp, div_folding(rem, div))//(c//div) return quo if rem is None else cast(UOp, div_folding(rem, div))//(c//div)+quo # ***** transcendental ***** @@ -197,7 +197,7 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, redu def index_collapse(idx,rng,buf,add,mul,ld,reduce): if rng not in reduce.src: return None - return UOp(reduce.op, reduce.dtype, (UOp(ld.op, ld.dtype, (buf, add+mul*idx, ld.const(0), idx.ge(rng.src[0]) & idx.lt(rng.src[1]))),)+ + return UOp(reduce.op, reduce.dtype, (UOp(ld.op, ld.dtype, (buf, add+mul*idx, ld.const_like(0), idx.ge(rng.src[0]) & idx.lt(rng.src[1]))),)+ tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg) # this is symbolic 2.0 @@ -251,23 +251,23 @@ constant_folder = PatternMatcher([ # max folding (NOp.max(NOp.var('x'), NOp.var('y')), lambda x,y: x if x.vmin.arg >= y.vmax.arg else y if 
x.vmax.arg <= y.vmin.arg else None), # GEP/CAST const rules - (NOp(UOps.GEP, src=(NOp.cvar("c"),), name="root"), lambda root, c: root.const(c.arg)), - (UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: root.const(c.arg)), + (NOp(UOps.GEP, src=(NOp.cvar("c"),), name="root"), lambda root, c: root.const_like(c.arg)), + (UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: root.const_like(c.arg)), # a conditional with the same results either way is a noop, also fold const conditionals (NOp.var().where(NOp.var("val"), NOp.var("val")), lambda val: val), (NOp.cvar('gate').where(NOp.var('c0'), NOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1), # ** constant folding ** - (UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: root.const(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))), + (UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))), # ** self folding ** # cast NOOP (NOTE: it's str to deal with PtrDType) (NOp(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None), (NOp(UOps.REDUCE, src=(NOp.var('x'),)), lambda x: x), # a REDUCE without ranges is a NOOP (NOp.var('x') + 0, lambda x: x), # x+0 -> x (NOp.var('x') * 1, lambda x: x), # x*1 -> x - (NOp.var('x') // NOp.var('x'), lambda x: x.const(1)), # x//x -> 1 + (NOp.var('x') // NOp.var('x'), lambda x: x.const_like(1)), # x//x -> 1 (NOp.var('x') // 1, lambda x: x), # x//1 -> x (NOp.var('x') // -1, lambda x: -x), # x//-1 -> -x - (NOp.var('x') / NOp.var('x'), lambda x: x.const(1)), # x/x -> 1 + (NOp.var('x') / NOp.var('x'), lambda x: x.const_like(1)), # x/x -> 1 ((NOp.var("x") * NOp.var("x2")) / NOp.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x (NOp.var('x', dtype=dtypes.bool) & NOp.cvar('c'), lambda x,c: x if c.arg else c), (NOp.var('x', dtype=dtypes.bool) | NOp.cvar('c'), lambda x,c: c if c.arg else x), @@ -275,9 +275,9 @@ constant_folder = PatternMatcher([ # x*0 -> 0 or 0*x -> 0 # if x is nan or inf it should render the nan value. # NOTE: this can be wrong for loaded NaN - (NOp.var('x') * 0, lambda x: x.const(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)), + (NOp.var('x') * 0, lambda x: x.const_like(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)), # min==max -> CONST (slow!) 
- (UPat({UOps.ALU, UOps.DEFINE_VAR}, name='x'), lambda x: x.const(x.vmin.arg) if x.vmin.arg == x.vmax.arg else None), + (UPat({UOps.ALU, UOps.DEFINE_VAR}, name='x'), lambda x: x.const_like(x.vmin.arg) if x.vmin.arg == x.vmax.arg else None), # ** load/store folding ** (NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.load(NOp.var("buf"), NOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)), # ** two stage add/mul folding ** @@ -381,7 +381,8 @@ def do_reduce(root:UOp): ret = root.src[0] if len(reduce_parented): assert root.dtype is not None - acc = UOp(UOps.DEFINE_ACC, root.dtype, (root.const(identity_element(root.arg, root.dtype.scalar())),) + tuple(reduce_parented), (acc_number,)) + acc = UOp(UOps.DEFINE_ACC, root.dtype, + (root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(reduce_parented), (acc_number,)) acc_number += 1 ret = UOp(UOps.PHI, root.dtype, (acc, acc.alu(root.arg, ret))) # for MAX, we can just ignore the unparented diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 682f761616..578dff118a 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -263,6 +263,7 @@ class UOp: dtype: Optional[DType] = None src: Tuple[UOp, ...] = tuple() arg: Any = None + def __hash__(self): return id(self) def commutative(self) -> bool: return (self.op is UOps.ALU and \ self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}) @@ -286,7 +287,7 @@ class UOp: assert ret.op is UOps.SHAPETRACKER, f"st_arg trying to return {ret}" return ret.arg def sink(self, *srcs): return UOp(UOps.SINK, None, (self,)+srcs) - def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x + def ufix(self, x): return self.const_like(x) if not isinstance(x, UOp) else x def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,)) def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,)) def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar() if self.dtype is not None else None, (self,), i) @@ -310,17 +311,18 @@ class UOp: def min(self, x): return -(-self).max(-x) def where(self, x, y): return self.alu(TernaryOps.WHERE, x, y) def recip(self): return self.alu(UnaryOps.RECIP) - def const(self:Union[UOp, DType, None], b:ConstType|Variable): return UOp._const(self.dtype if isinstance(self, UOp) else self, b) - def sconst(self:Union[UOp, DType, None], b:ConstType|Variable): - return UOp._const(cast(DType, self.dtype if isinstance(self, UOp) else self).scalar() if self is not None else self, b) + def const_like(self, b:ConstType|Variable): return type(self).const(self.dtype, b) + def sconst_like(self, b:ConstType|Variable): return type(self).const(self.dtype.scalar() if self.dtype is not None else None, b) + @classmethod + @functools.lru_cache(None) + def const(cls, dtype:Optional[DType], b:ConstType|Variable): return UOp._const(cls, dtype, b) @staticmethod - @functools.lru_cache(maxsize=None) - def _const(dtype:Optional[DType], b:ConstType|Variable): + def _const(typ, dtype:Optional[DType], b:ConstType|Variable): # TODO: fix dtype of b.max after Variable is just an UOp - if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (UOp.const(dtypes.int, b.min), UOp.const(dtypes.int, cast(int,b.max))), b) + if isinstance(b, Variable): return typ(UOps.DEFINE_VAR, dtype, (typ.const(dtypes.int, b.min), typ.const(dtypes.int, cast(int,b.max))), b) if dtype is not None and dtype != (sdtype := dtype.scalar()): - return UOp(UOps.VECTORIZE, dtype, src=tuple(UOp(UOps.CONST, sdtype, arg=dtypes.as_const(b, 
sdtype)) for _ in range(dtype.count))) - return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) + return typ(UOps.VECTORIZE, dtype, src=tuple(typ(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count))) + return typ(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) def alu(self, arg, *src:UOp): return type(self)(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else (self, *src)[-1].dtype, (self,)+src, arg) @staticmethod @@ -349,7 +351,7 @@ class UOp: return 1 def divides(self, v) -> Optional[UOp]: if v==1: return self - if self.op is UOps.CONST: return self.const(self.arg//v) if self.arg%v == 0 else None + if self.op is UOps.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None if self.op is UOps.ALU: if self.arg is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None if self.arg is BinaryOps.MUL: @@ -357,32 +359,34 @@ class UOp: if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1 return None # generic None if we aren't sure @property - def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.min(cast(DType, self.dtype))) + def vmin(self) -> UOp: + return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.sconst_like(dtypes.min(cast(DType, self.dtype))) @property - def vmax(self) -> UOp: return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.max(cast(DType, self.dtype))) + def vmax(self) -> UOp: + return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.sconst_like(dtypes.max(cast(DType, self.dtype))) @functools.cached_property def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]: # NOTE: returned UOp is assumed to be CONST if self.op is UOps.DEFINE_VAR and self.src: return self.src[0], self.src[1] if isinstance(self.src[1].arg, int) else None if self.op is UOps.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax # TODO: UOps.SPECIAL is UOps.DEFINE_VAR - if self.op is UOps.SPECIAL: return self.const(0), self.const(self.arg[1]-1) if isinstance(self.arg[1], int) else None + if self.op is UOps.SPECIAL: return self.const_like(0), self.const_like(self.arg[1]-1) if isinstance(self.arg[1], int) else None if self.op is UOps.CONST: return self, self if self.op is UOps.ALU and cast(DType, self.dtype).count == 1: s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)] - if self.arg is BinaryOps.ADD: return self.sconst(s0.vmin.arg+s1.vmin.arg), self.sconst(s0.vmax.arg+s1.vmax.arg) + if self.arg is BinaryOps.ADD: return self.sconst_like(s0.vmin.arg+s1.vmin.arg), self.sconst_like(s0.vmax.arg+s1.vmax.arg) if self.arg is BinaryOps.MUL and (s0.vmin.arg >= 0 or s1.vmin.arg >= 0): # handle at lease one is non-negative Lmin, Lmax = (s0.vmin.arg, s0.vmax.arg) if s1.vmin.arg >= 0 else (s0.vmax.arg, s0.vmin.arg) Rmin, Rmax = (s1.vmin.arg, s1.vmax.arg) if s0.vmin.arg >= 0 else (s1.vmax.arg, s1.vmin.arg) assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}" - return self.sconst(Lmin*Rmin), self.sconst(Lmax*Rmax) - if self.arg is BinaryOps.MOD and s1.vmin.arg > 0: return self.sconst(0), self.sconst(s1.vmax.arg-1) + return self.sconst_like(Lmin*Rmin), self.sconst_like(Lmax*Rmax) + if self.arg is BinaryOps.MOD and s1.vmin.arg > 0: return self.sconst_like(0), 
self.sconst_like(s1.vmax.arg-1)
       if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
-        if s1.arg > 0: return self.sconst(s0.vmin.arg//s1.arg), self.sconst(s0.vmax.arg//s1.arg)
-        if s1.arg < 0: return self.sconst(-(s0.vmax.arg//-s1.arg)), self.sconst(-(s0.vmin.arg//-s1.arg))
-      if self.arg is BinaryOps.MAX: return self.sconst(max(s0.vmin.arg, s1.vmin.arg)), self.sconst(max(s0.vmax.arg, s1.vmax.arg))
-      if self.arg is BinaryOps.CMPLT: return (UOp.sconst(dtypes.bool, s0.vmax.arg<s1.vmin.arg), UOp.sconst(dtypes.bool, s0.vmin.arg<s1.vmax.arg))
+        if s1.arg > 0: return self.sconst_like(s0.vmin.arg//s1.arg), self.sconst_like(s0.vmax.arg//s1.arg)
+        if s1.arg < 0: return self.sconst_like(-(s0.vmax.arg//-s1.arg)), self.sconst_like(-(s0.vmin.arg//-s1.arg))
+      if self.arg is BinaryOps.MAX: return self.sconst_like(max(s0.vmin.arg, s1.vmin.arg)), self.sconst_like(max(s0.vmax.arg, s1.vmax.arg))
+      if self.arg is BinaryOps.CMPLT: return (UOp.const(dtypes.bool, s0.vmax.arg<s1.vmin.arg), UOp.const(dtypes.bool, s0.vmin.arg<s1.vmax.arg))
     return None, None
 
@@ ... @@ def get_location() -> Tuple[str, int]:
   while (frm.f_code.co_filename.endswith("/ops.py") or frm.f_code.co_filename == '<string>') and frm.f_back is not None: frm = frm.f_back
   return frm.f_code.co_filename, frm.f_lineno
 
 @functools.lru_cache(None)
-def lines(fn): return open(fn).readlines()
+def lines(fn) -> List[str]: return open(fn).readlines()
 
 @dataclass(frozen=True, repr=False) # reuse repr from UOp
 class NOp(UOp):
@@ -414,7 +418,11 @@ class NOp(UOp):
   @staticmethod
   @functools.lru_cache(None)
   def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.CONST, dtype=dtype, name=name)
-  def const(self:Union[UOp, DType, None], b:ConstType|Variable): return NOp((x:=UOp.const(self, b)).op, x.dtype, x.src, x.arg)
+
+  # this is needed so NOp has a different cache
+  @classmethod
+  @functools.lru_cache(None)
+  def const(cls, dtype:Optional[DType], b:ConstType|Variable): return cls._const(cls, dtype, b)
 
   @functools.cached_property
   def upat(self:NOp) -> UPat:
@@ -447,7 +455,11 @@ class UPat:
     upat_match = [self.in_src] if isinstance(self.in_src, UPat) else ([] if self.in_src is None else self.src[0])
     self.early_reject = set((pp.op[0], pp.arg) for pp in upat_match if pp.op is not None and len(pp.op) == 1)
 
-  def printable(self:UPat): return lines(self.location[0])[self.location[1]-1].strip()
+  def printable(self:UPat) -> str:
+    try:
+      return lines(self.location[0])[self.location[1]-1].strip()
+    except FileNotFoundError:
+      return ""
 
   def __repr__(self):
     def rep(x):
       form = "UPat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)"
diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py
index e36a7b8cd3..745818b1d3 100644
--- a/tinygrad/renderer/assembly.py
+++ b/tinygrad/renderer/assembly.py
@@ -66,7 +66,7 @@ ptx_matcher = PatternMatcher([
        UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))), lambda root, alu, const: UOp(root.op, root.dtype,
     (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
-      const.const(root.src[0].dtype.itemsize)*const)+root.src[2:])),
+      const*root.src[0].dtype.itemsize)+root.src[2:])),
  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
                                                                       UPat(UOps.CONST, name="const"))), lambda root, const: UOp(root.op, root.dtype,
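
Usage sketch (illustrative only, not part of the patch): the core API change above turns UOp.const into a cached classmethod that takes an explicit dtype, while const_like/sconst_like build a constant from an existing UOp's dtype. Assuming the usual tinygrad import paths, it can be exercised like this:

# illustrative sketch of the post-change API, not part of the patch
from tinygrad.ops import UOp, BinaryOps
from tinygrad.dtype import dtypes

x = UOp.const(dtypes.int, 3)            # classmethod: dtype is passed explicitly
y = x.const_like(4)                     # instance helper: reuses x.dtype (dtypes.int)
z = x.alu(BinaryOps.ADD, y)             # ALU node over the two constants
assert UOp.const(dtypes.int, 3) is x    # functools.lru_cache on the classmethod dedups constants

The cache is keyed per class, which is why NOp overrides const (per the "different cache" comment above): NOp pattern templates must not be interned as plain UOp constants.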