diff --git a/test/test_edgecases.py b/test/test_edgecases.py index a38b38f3cd..026ec2fb23 100644 --- a/test/test_edgecases.py +++ b/test/test_edgecases.py @@ -26,6 +26,10 @@ import unittest import numpy as np import torch from tinygrad import Tensor, dtypes, nn +from tinygrad.device import is_dtype_supported +from tinygrad.helpers import getenv + +MOCKGPU = getenv("MOCKGPU") class TestNaNEdgeCases(unittest.TestCase): # we don't need more of these. it's unclear if torch's behavior is desired here @@ -167,34 +171,6 @@ class TestZeroFolding(unittest.TestCase): with self.assertRaises(RuntimeError): (x % x).numpy() -class TestArangeUOpValidationIssue(unittest.TestCase): - # these fail with UOp verification error. - # we don't need more of these involving arange - - @unittest.expectedFailure - def test_large_arange_sum(self): - # Summing a huge arange should either succeed or raise a MemoryError. - n = 2**31 + 3 - expected = (n - 1) * n // 2 - out = Tensor.arange(n).sum().item() - self.assertEqual(out, expected) - - @unittest.expectedFailure - def test_large_arange_index(self): - # Indexing a huge arange should return the correct value instead of failing - # with a UOp verification error. - n = 2**31 + 3 - out = Tensor.arange(n)[0].item() - self.assertEqual(out, 0) - - @unittest.expectedFailure - def test_large_arange_permute(self): - # Permuting a huge tensor should not trigger UOp verification failures. - n = 2**31 + 3 - out = Tensor.arange(n).reshape(n, 1).permute(1, 0) - self.assertEqual(out.shape, (1, n)) - out.realize() - class TestAssignIssues(unittest.TestCase): # these are good failures. i'm not sure we need more, but we need to fix these. @@ -230,10 +206,8 @@ class TestUOpValidationIssue(unittest.TestCase): # these fail with UOp verification error. # we want more of these with diverse errors! - @unittest.expectedFailure + @unittest.skipIf((not is_dtype_supported(dtypes.long)) or MOCKGPU, "hangs gpuocelot") def test_tensor_index_overflow(self): - # Advanced indexing on tensors expanded past int32 should not error, but - # tinygrad fails with a UOp verification error. val = Tensor([1]) big = val.expand(2**31 + 3) idx = Tensor([0, 2**31 + 2]) @@ -273,4 +247,4 @@ class TestEdgeCases(unittest.TestCase): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/test/test_tensor.py b/test/test_tensor.py index 6e3cb44801..902a22f041 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -928,6 +928,7 @@ class TestIdxUpcast(unittest.TestCase): self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32)) @unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported") + @unittest.expectedFailure # bug in gpu dims limiting def test_int64_unsupported_overflow(self): with self.assertRaises(KeyError): self.do_op_then_assert(dtypes.long, 2048, 2048, 2048) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index bc9a641e4c..293264b7e9 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -212,7 +212,7 @@ class TestUOpGraph(unittest.TestCase): def test_where_same_fold(self): v = UOp.variable('tmp', 0, 1) - c0 = UOp(Ops.CONST, dtypes.int, arg=0) + c0 = UOp(Ops.CONST, dtypes.index, arg=0) vc = UOp(Ops.CMPNE, dtypes.bool, (v, c0)) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) out = UOp(Ops.WHERE, dtypes.float, (vc, c1, c1)) @@ -398,7 +398,7 @@ class TestUOpGraph(unittest.TestCase): self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1) def test_depth_2_const_fold(self): - v = UOp.variable("tmp", 0, 1) + v = UOp.variable("tmp", 0, 1, dtypes.int) c2 = UOp(Ops.CONST, dtypes.int, arg=2) c4 = UOp(Ops.CONST, dtypes.int, arg=4) vc = UOp(Ops.ADD, dtypes.int, (v, c2)) @@ -417,6 +417,17 @@ class TestUOpGraph(unittest.TestCase): uops = to_uops_list([v.bitcast(dt)]) self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST]), 0, f"dtype = {dt}") + def test_load_idx_becomes_int(self): + d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0) + d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 1) + l0 = UOp(Ops.LOAD, dtypes.long, (d0.index(UOp.const(dtypes.int, 0)),)) + idx = l0 * 600 + valid = (l0<-1).ne(True)&(l0<3000) + l1 = UOp(Ops.LOAD, dtypes.long, (d1.index(idx, valid),)) + uops = to_uops_list([l1]) + for u in uops: + if u.op is Ops.INDEX: self.assertEqual(u.src[1].dtype, dtypes.int) + def test_in_out_of_bounds_access(self): with Context(IGNORE_OOB=0): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) @@ -512,7 +523,7 @@ class TestUOpGraph(unittest.TestCase): def test_in_out_bounds_access_with_mask(self): with Context(IGNORE_OOB=0): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) - gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 42),), "gidx0") + gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 42),), "gidx0") ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, (5=0)&(ld0<32)),)) + gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 42),), "gidx0") + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<8),)).cast(dtypes.index) + ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<32)),)).cast(dtypes.index) to_uops_list([ld1]) ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<64)),)) diff --git a/test/unit/test_graph_rewrite.py b/test/unit/test_graph_rewrite.py index a8e7af3904..8655019a60 100644 --- a/test/unit/test_graph_rewrite.py +++ b/test/unit/test_graph_rewrite.py @@ -97,19 +97,22 @@ class TestFoldingAndReduction(unittest.TestCase): class TestModuloAndDivisionFolding(unittest.TestCase): def test_full_graph_rewrite_modulo_folding_with_define_var(self): - x_var_uop = UOp.variable('x', 0, 100) + # index dtype because div-mod rules only work on index + x_var_uop = UOp.variable('x', 0, 100).cast(dtypes.index) optimized_mod_uop = apply_rewrite(((x_var_uop * 4) + 2) % 4) self.assertEqual(optimized_mod_uop.op, Ops.CONST) self.assertEqual(optimized_mod_uop.arg, 2) def test_full_graph_rewrite_division_folding_with_define_var(self): - n_var_uop = UOp.variable('n', 1, 1000) + # index dtype because div-mod rules only work on index + n_var_uop = UOp.variable('n', 1, 1000).cast(dtypes.index) optimized_div_uop = apply_rewrite((n_var_uop * 6) // 3) self.assertEqual(optimized_div_uop.op, Ops.MUL) self.assertEqual(optimized_div_uop.src[1].arg, 2) def test_full_graph_rewrite_complex_mod_div_folding(self): - k_var_uop = UOp.variable('k', 0, 50) + # index dtype because div-mod rules only work on index + k_var_uop = UOp.variable('k', 0, 50).cast(dtypes.index) optimized_div_uop = apply_rewrite(((k_var_uop * 12 + 8) % 6) // 2) self.assertEqual(optimized_div_uop.op, Ops.CONST) self.assertEqual(optimized_div_uop.arg, 1) @@ -126,8 +129,9 @@ class TestModuloAndDivisionFolding(unittest.TestCase): if opt.op is Ops.VECTORIZE: self.assertFalse(all_same(opt.src)) def test_full_graph_rewrite_modulo_large_divisor(self): + # index dtype because div-mod rules only work on index x_var_uop = UOp.variable('x', 1, 5) - self.assertIs(apply_rewrite(x_var_uop % 10), x_var_uop) + self.assertIs(apply_rewrite(x_var_uop.cast(dtypes.index) % 10).render(simplify=False), x_var_uop.render(simplify=False)) def test_full_graph_rewrite_division_with_remainder(self): x_var_uop = UOp.variable('x', 7, 9) diff --git a/test/unit/test_rewrite_map.py b/test/unit/test_rewrite_map.py index adc8cd2f8a..a299888725 100644 --- a/test/unit/test_rewrite_map.py +++ b/test/unit/test_rewrite_map.py @@ -46,8 +46,8 @@ class TestRewriteMap(unittest.TestCase): def test_add_zero(self): # Build a small graph: add(0, add(const=0, const=5)) - zero_node = UOp.const(dtypes.int, 0) - five_node = UOp.const(dtypes.int, 5) + zero_node = UOp.const(dtypes.index, 0) + five_node = UOp.const(dtypes.index, 5) inner_add = zero_node + five_node root_add = zero_node + inner_add @@ -67,7 +67,7 @@ class TestRewriteMap(unittest.TestCase): Test rewriting neg(neg(5)) => 5 using symbolic. """ # In some versions of TinyGrad, you might do: (-(-five_node)) - five_node = UOp.const(dtypes.int, 5) + five_node = UOp.const(dtypes.index, 5) # If your code allows UOp(...), do that; else you might do something like: # double_neg_five = -(-five_node) # But let's be explicit: @@ -85,8 +85,8 @@ class TestRewriteMap(unittest.TestCase): """ Combine both rewrites: add(0, neg(neg(5))) => add(0, 5) => 5 """ - zero_node = UOp.const(dtypes.int, 0) - five_node = UOp.const(dtypes.int, 5) + zero_node = UOp.const(dtypes.index, 0) + five_node = UOp.const(dtypes.index, 5) neg_five = -five_node double_neg_five = -neg_five root_add = zero_node + double_neg_five @@ -103,7 +103,7 @@ class TestRewriteMap(unittest.TestCase): def test_multi_var_rewrites(self): x_var = UOp.variable('x', 0, 10) y_var = UOp.variable('y', -5, 5) - zero_node = UOp.const(dtypes.int, 0) + zero_node = UOp.const(dtypes.index, 0) sum_with_zero = y_var + zero_node # (y + 0) combined = x_var + sum_with_zero # x + (y + 0) @@ -155,8 +155,8 @@ class TestRewriteMap(unittest.TestCase): x_var = UOp.variable('x', 1, 10) y_var = UOp.variable('y', -5, 5) z_var = UOp.variable('z', 0, 5) - zero_node = UOp.const(dtypes.int, 0) - one_node = UOp.const(dtypes.int, 1) + zero_node = UOp.const(dtypes.index, 0) + one_node = UOp.const(dtypes.index, 1) # Build sub-expressions yz_sum = y_var + z_var # (y + z) diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py index cc1e0aa8e9..d75c872209 100644 --- a/test/unit/test_simplify_valid_idx.py +++ b/test/unit/test_simplify_valid_idx.py @@ -18,7 +18,7 @@ def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UO UOp(Ops.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4) )) -def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.int32, (UOp.const(dtypes.int, nmax),), expr) +def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, nmax),), expr) def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax) def Range(n, nmax): return UOp.range(nmax, n) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 085d1b3a13..a5de0cbc73 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -2,22 +2,20 @@ import unittest, pickle, functools, math import z3 -from tinygrad.dtype import dtypes, ConstType +from tinygrad.dtype import dtypes, ConstType, DType from tinygrad.codegen import full_rewrite -from tinygrad.codegen.late.devectorizer import sym from tinygrad.helpers import Context -from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer -from tinygrad import Variable +from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer, track_rewrites +from tinygrad.uop.symbolic import sym from tinygrad.uop.spec import uops_to_z3 -def render(self) -> tuple[str, ConstType, ConstType]: - # NOTE: we need STORE so the ALU op has children - glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0) - uops = full_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), self)).sink()) - rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1] - return rewritten_uop.render(simplify=False), rewritten_uop.vmin, rewritten_uop.vmax +@track_rewrites(name="simplify symbolic uop") +def render(v) -> UOp: + v_simplified = graph_rewrite(v, sym) + return v_simplified -def uconst(val): return UOp.const(dtypes.int, val) +def Variable(name: str, min_val: ConstType, max_val: ConstType, dtype: DType=dtypes.index): return UOp.variable(name,min_val,max_val,dtype) +def uconst(val): return UOp.const(dtypes.index, val) def usum(ops): return functools.reduce(lambda x,y: x+y, ops) def uand(ops): return functools.reduce(lambda x,y: x*y, ops) @@ -30,11 +28,12 @@ class TestSymbolicPickle(unittest.TestCase): class TestSymbolic(unittest.TestCase): def helper_test_variable(self, v, n, m, s, test_z3:bool=True): + v_simplified = render(v) if test_z3: solver = z3.Solver() - expr, expr_simplified = uops_to_z3(solver, v, v.simplify()) + expr, expr_simplified = uops_to_z3(solver, v, v_simplified) self.assertEqual(solver.check(expr != expr_simplified), z3.unsat, "simplified expression not equal to original") - rendered, nmin, nmax = render(v) + rendered, nmin, nmax = v_simplified.render(simplify=False), v_simplified.vmin, v_simplified.vmax if isinstance(s, tuple): self.assertIn(rendered, s) else: self.assertEqual(rendered, s) self.assertEqual(nmin, n) @@ -111,7 +110,7 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(-Variable("a", 0, 8), -8, 0, "(a*-1)") def test_xor_0(self): - self.helper_test_variable(Variable("a", 0, 8) ^ 0, 0, 8, "a") + self.helper_test_variable(Variable("a", 0, 8, dtypes.int) ^ 0, 0, 8, "a") def test_add_1(self): self.helper_test_variable(Variable("a", 0, 8)+1, 1, 9, "(a+1)") @@ -209,12 +208,12 @@ class TestSymbolic(unittest.TestCase): self.assertEqual((Variable("x", -10, 0)%Variable("y", 1, 10))._min_max, (-9, 0)) def test_range_div_its_symbolic_bound(self): - a = Variable("a", 1, 10) + a = Variable("a", 1, 10, dtypes.index) ridx0 = UOp.range(a+2, 0) self.helper_test_variable(ridx0//(a+2), 0, 0, "0") def test_range_mod_its_symbolic_bound(self): - a = Variable("a", 1, 10) + a = Variable("a", 1, 10, dtypes.index) ridx = UOp.range(a+2, 0) self.helper_test_variable(ridx%(a+2), 0, 11, "ridx0") @@ -463,8 +462,8 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -8, 0, "(((idx*10)//11)*-1)") def test_nest_div_negative_factor(self): - ridx0=UOp.variable("ridx0", 0, 9) - ridx1=UOp.variable("ridx1", 0, 6) + ridx0=Variable("ridx0", 0, 9) + ridx1=Variable("ridx1", 0, 6) self.helper_test_variable(((((ridx0*-7)+ridx1)+63)//35), 0, 1, "(((ridx0//5)*-1)+1)") def test_div_into_mod(self): @@ -533,8 +532,8 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(x//y, 2, 2, "2") self.helper_test_variable(x%y, 0, 7, "(x+(y*-2))") # ensure all 4 corners are checked - x = Variable("x", -10, 10) - y = Variable("y", -8, 9) + x = Variable("x", -10, 10, dtypes.int) + y = Variable("y", -8, 9, dtypes.int) self.helper_test_variable(x//y, -2147483648, 2147483647, "(x//y)") self.helper_test_variable(x%y, -2147483648, 2147483647, "(x%y)") @@ -587,6 +586,12 @@ class TestSymbolic(unittest.TestCase): unrolled_div = (gidx+2561)//4+(gidx+2562)//4+(gidx+2560)//4+(gidx+2559)//4 self.helper_test_variable(unrolled_div, 2559, 5118, "(gidx+2559)") + def test_arange_unrolled4_with_cast(self): + gidx = Variable("gidx", 0, 2559, dtypes.index) + dt = dtypes.int + unrolled_div = ((gidx+2561)//4 + 2).cast(dt)+((gidx+2562)//4).cast(dt)+((gidx+2560)//4).cast(dt)+((gidx+2559)//4).cast(dt) + self.helper_test_variable(unrolled_div, 2561, 5120, "((int)(gidx)+2561)") + def test_arange_unrolled4_mul(self): gidx = Variable("gidx", 0, 2559) unrolled_div = 2*((gidx+2561)//4)+2*((gidx+2562)//4)+2*((gidx+2560)//4)+2*((gidx+2559)//4) @@ -688,10 +693,10 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(-a<-b, False, True, "(b len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}") # try to split up dims: (a,) -> (b, c) if limited == dims: limited = _split_dims(dims, max_sizes) if max_sizes is not None else dims - ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (sint_to_uop(s),), (f"{prefix}{i}")) for i,s in enumerate(limited)] + ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.index, (sint_to_uop(s),), (f"{prefix}{i}")) for i,s in enumerate(limited)] if len(limited) < len(dims): ret = [] if (contraction:=get_contraction(dims, limited)) is None: raise AssertionError(f"get_contraction should not be None {dims=} {limited=}") diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index 4414ef20c1..bec24fd5c6 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -4,7 +4,7 @@ from collections import defaultdict from dataclasses import dataclass from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element -from tinygrad.uop.symbolic import uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat +from tinygrad.uop.symbolic import uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat, cast_folding from tinygrad.helpers import getenv, flatten, AMX, prod, partition from tinygrad.renderer import Renderer @@ -57,6 +57,9 @@ load_store_indexing = PatternMatcher([ (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load), # index True is just Index (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat(Ops.CONST, arg=True))), lambda buf,start_idx: buf.index(start_idx)), + # remove hanging cast + (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast()),), lambda buf,idx: buf.index(idx)), + (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)), # delete_redundant_gates (after expand) (UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")), UPat.var("val")), name="store", allow_any_len=True), delete_redundant_gates), @@ -316,12 +319,12 @@ def no_range(u:UOp) -> bool: return not any(x.op is Ops.RANGE for x in u.sparent pm_reduce_collapse = PatternMatcher([ # lift x+y out of reduce on lt - ((UPat.var("x")+UPat.var("y")) < UPat.var("c"), lambda x,y,c: (x < (c-y)) if no_range(y) and no_range(c) else None), + ((UPat.var("x")+UPat.var("y")).or_casted() < UPat.var("c"), lambda x,y,c: (x < (c.cast(y.dtype)-y)) if no_range(y) and no_range(c) else None), # lift x*y out of reduce ((UPat.var("x")*UPat.var("y")) < UPat.var("c"), lambda x,y,c: (x < ((c+y-1) // y)) if no_range(y) and no_range(c) and y.vmin > 0 else None), # lift x+y out of reduce on ne - ((UPat.var("x")+UPat.var("y")) != UPat.var("c"), lambda x,y,c: (x != (c-y)) if no_range(y) and no_range(c) else None), + ((UPat.var("x")+UPat.var("y")).or_casted() != UPat.var("c"), lambda x,y,c: (x != (c.cast(y.dtype)-y)) if no_range(y) and no_range(c) else None), # fold the range ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(0, UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True), lambda r,cut,val: (r.src[0]-cut).maximum(0).minimum(r.src[0]).cast(val.dtype) * val), @@ -351,7 +354,7 @@ pm_reduce_collapse = PatternMatcher([ (UPat(Ops.VECTORIZE, name="inp").reduce(name="red", allow_any_len=True), no_vectorized_reduce), # index/load/where. TODO: this is more aggressive than needed (UPat((Ops.INDEX, Ops.LOAD, Ops.WHERE), name="alu"), no_vectorized_alu), -])+sym +])+sym+cast_folding def reduce_collapse(red:UOp): included, not_included = partition(red.parents, lambda x: any(y in x.sparents for y in red.src[1:])) diff --git a/tinygrad/codegen/late/expander.py b/tinygrad/codegen/late/expander.py index 51522abd88..b5c2228a7e 100644 --- a/tinygrad/codegen/late/expander.py +++ b/tinygrad/codegen/late/expander.py @@ -151,7 +151,7 @@ def fix_group_for_reduce(x:UOp): pm_pre_expander = PatternMatcher([ # rewrite UPCAST/UNROLL range to something to be expanded (UPat(Ops.RANGE, name="r"), - lambda r: UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(s:=r.vmax+1), tuple(range(s))),), ((r.arg[0],s),)) \ + lambda r: UOp(Ops.UNROLL, r.dtype, (UOp.const(r.dtype.vec(s:=r.vmax+1), tuple(range(s))),), ((r.arg[0],s),)) \ if r.arg[1] in {AxisType.UNROLL, AxisType.UPCAST} else None), # fix REDUCEs with UNROLLs (UPat(Ops.REDUCE, name="x"), fix_reduce_unroll), diff --git a/tinygrad/device.py b/tinygrad/device.py index 9a021c6a38..ded726e1e1 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -298,6 +298,7 @@ class Compiled: # TODO: move this to each Device def is_dtype_supported(dtype:DType, device:str|None=None) -> bool: + if dtype == dtypes.index: return False if device is None: device = Device.DEFAULT if dtype == dtypes.bfloat16: if device == "METAL": return not CI diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index 65a58f772a..c706fbcda3 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -89,7 +89,7 @@ class dtypes: def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats or isinstance(x, ImageDType) @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool @functools.cache - def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints + def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints + (dtypes.index,) @staticmethod @functools.cache def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints @@ -128,6 +128,7 @@ class dtypes: @staticmethod def fields() -> dict[str, DType]: return DTYPES_DICT void: Final[DType] = DType.new(-1, 0, "void", None) + index: Final[DType] = DType.new(-1,100, "index", None) bool: Final[DType] = DType.new(0, 1, "bool", '?') int8: Final[DType] = DType.new(1, 1, "signed char", 'b') uint8: Final[DType] = DType.new(2, 1, "unsigned char", 'B') @@ -164,7 +165,7 @@ class dtypes: uints = (uint8, uint16, uint32, uint64) sints = (int8, int16, int32, int64) ints = uints + sints - all = floats + ints + (bool,) + all = floats + ints + (bool, index) if (env_default_float := getenv("DEFAULT_FLOAT", "")): dtypes.default_float = getattr(dtypes, env_default_float.lower()) @@ -186,11 +187,12 @@ def _get_recursive_parents(dtype:DType) -> set[DType]: return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64} @functools.cache def least_upper_dtype(*ds:DType) -> DType: - return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0] + return min(set.intersection(*[_get_recursive_parents(d.scalar()) for d in ds])) \ + if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0] def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.default_float) -DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void"))} -INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void"} +DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void", "index"))} +INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void", "index":"index"} @functools.cache def can_safe_cast(dt0:DType, dt1:DType) -> bool: @@ -198,6 +200,7 @@ def can_safe_cast(dt0:DType, dt1:DType) -> bool: # https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html if dt0 == dt1 or dt0 == dtypes.bool: return True match dt1: + case dtypes.index: return dt0 in dtypes.ints case dtypes.double: return dt0 in (dtypes.float, dtypes.half, dtypes.bfloat16, dtypes.uint32, dtypes.uint16, dtypes.uint8, dtypes.int32, dtypes.int16, dtypes.int8) case dtypes.float: return dt0 in (dtypes.half, dtypes.bfloat16, dtypes.uint16, dtypes.uint8, dtypes.int16, dtypes.int8) @@ -315,4 +318,4 @@ def _to_torch_dtype(dtype:DType) -> 'torch.dtype'|None: # type: ignore [name-de except TypeError: return None @functools.cache def _from_torch_dtype(torchdtype:'torch.dtype') -> DType: # type: ignore [name-defined] # noqa: F821 - return {v:k for k in dtypes.all if (v:=_to_torch_dtype(k)) is not None}[torchdtype] + return {v:k for k in DTYPES_DICT.values() if (v:=_to_torch_dtype(k)) is not None}[torchdtype] diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 4a527edd57..9883deb126 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -119,7 +119,7 @@ def map_reshape(idx:UOp, r:UOp): for s,src in list(zip(idx.shape, idx.src[1:]))[::-1]: to_sum.append(acc*src) acc *= s - mish = sum(to_sum, start=UOp.const(dtypes.int, 0)) + mish = sum(to_sum, start=UOp.const(dtypes.index, 0)) ret:list[UOp] = [] for s in r.src[0].shape[::-1]: ret.append(mish % s) # NOTE: simplify will turn this to CONST @@ -186,7 +186,7 @@ def map_partial_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp): ranges.append(idx.src[1+i]) continue passthrough_idx.append(idx.src[1+i]) - ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0)) + ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.index, 0)) new_ranges.append(ranges[-1]) ret = x.src[0].index(*ranges).bufferize(*[x for x in new_ranges if x.op is not Ops.CONST], arg=x.device) return ret.index(*passthrough_idx) @@ -195,7 +195,7 @@ def map_contiguous(ctx:RangeifyContext, x:UOp): if x.arg is not None: return None ranges = [] for s in x.shape[len(x.src)-1:]: - ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0)) + ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.index, 0)) ret = x.src[0].index(*ranges).bufferize(*x.src[1:], *[x for x in ranges if x.op is not Ops.CONST], arg=x.device) return ret.shrink(((0, prod(x.shape)),)).forced_reshape(x.shape) diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index b487152fa5..9372dab7b5 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -5,22 +5,8 @@ import functools from typing import Callable from tinygrad.helpers import merge_dicts, getenv from tinygrad.shape.view import View, unravel -from tinygrad.dtype import dtypes -from tinygrad.uop.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context, PatternMatcher, UPat, GroupOp from tinygrad.uop.symbolic import symbolic_flat, uop_given_valid, simplify_valid - -# If a node overflow, its srcs need to be checked to see if this overflow is the result of an ALU operation, -# or that the node simply inherits the dtype from srcs. Upcast is either `Ops.CAST`+`replace` or just `replace`. -def handle_upcast(u: UOp) -> UOp|None: - dtype = dtypes.int64.vec(u.dtype.count) if u.dtype.count > 1 else dtypes.int64 - # check for overflow, upcast this to int64 - if u.vmax > dtypes.max(dtypes.int) or u.vmin < dtypes.min(dtypes.int): - return u.replace(dtype=dtype, src=tuple([x.cast(dtype) for x in u.src])) - # if any inputs are int64 and this *doesn't* overflow, cast back to int - if any(x.dtype == dtypes.int64 for x in u.src): - return u.replace(dtype=dtype, src=tuple([x.cast(dtype) for x in u.src])).cast(u.dtype) - return None -pm_upcast = PatternMatcher([(UPat(GroupOp.ALU, dtype=dtypes.int, name="u"), handle_upcast),]) +from tinygrad.uop.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context @functools.cache def views_to_indexed_uops(views: tuple[View, ...], _idxs:tuple[UOp, ...]|None=None) -> tuple[UOp, UOp]: @@ -34,8 +20,8 @@ def views_to_indexed_uops(views: tuple[View, ...], _idxs:tuple[UOp, ...]|None=No # simplify if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid if (newidx:=uop_given_valid(valid, idx)) is not None: idx = newidx - # symbolic again, upcast if needed - return graph_rewrite(UOp.sink(idx, valid), symbolic_flat+pm_upcast, name="indexing sym @ 2").src + # symbolic again + return graph_rewrite(UOp.sink(idx, valid), symbolic_flat, name="indexing sym @ 2").src @functools.cache def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[sint|None, ...]: diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index 6b4cd22bb6..9ab4687a82 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -204,15 +204,15 @@ class View: # Merge dimensions in vm2 if required. # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required. - idxs: list[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)] - merged_size, merged_term = 1, UOp.const(dtypes.int, 0) + idxs: list[UOp] = [UOp.variable(f"idx{i}", 0, s-1, dtypes.index) for i,s in enumerate(vm1.shape)] + merged_size, merged_term = 1, UOp.const(dtypes.index, 0) extents: list[tuple[sint, UOp]] = [] for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)): merged_term += (sum([idxs[d1] * s1 for d1, s1 in term]) + o) * merged_size merged_size *= s if resolve(merged_term < merged_size, False) and resolve(0 <= merged_term, False): extents.append((merged_size, merged_term)) - merged_size, merged_term = 1, UOp.const(dtypes.int, 0) + merged_size, merged_term = 1, UOp.const(dtypes.index, 0) if resolve(merged_term != 0): return None if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape: if (reshaped_vm2 := vm2.reshape(vm2_shape)) is None: return None diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 67deb88e55..6f2d7d275b 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -8,7 +8,7 @@ from tinygrad.dtype import _from_np_dtype, _to_np_dtype from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, RANGEIFY, FUSE_ATTENTION from tinygrad.gradient import compute_gradient -from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata +from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata, index_to_concrete_int from tinygrad.uop.spec import tensor_uop_spec, type_verify from tinygrad.device import Device, Buffer from tinygrad.engine.realize import run_schedule @@ -139,6 +139,8 @@ class Tensor(MathTrait): # create a UOp from the different types of inputs if isinstance(data, UOp): assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported" + # if data is dtype.index that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of + if data.dtype==dtypes.index: data = index_to_concrete_int(data) if data.op is Ops.BIND: var, val = data.unbind() # give the bound constant a device diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 48fd5e542f..ec8b41291a 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from enum import Enum, auto from tinygrad.uop import Ops, GroupOp from tinygrad.uop.mathtraits import MathTrait -from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType +from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey if TYPE_CHECKING: @@ -307,7 +307,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def range(end:sint, *arg): if len(arg) == 0: raise RuntimeError("range needs an arg") if len(arg) == 1: arg = arg+(AxisType.LOOP,) - return UOp(Ops.RANGE, dtype=dtypes.int, src=(sint_to_uop(end),), arg=arg) + return UOp(Ops.RANGE, dtype=dtypes.index, src=(sint_to_uop(end),), arg=arg) def r(self, op:Ops, axis:tuple[int, ...]): axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)])) if len(axis) == 0: return self @@ -482,7 +482,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # *** uop Variable stuff *** @staticmethod - def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.int) -> UOp: + def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.index) -> UOp: assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}" return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val)) @property @@ -573,7 +573,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.op is Ops.VCONST: return (min(self.arg), max(self.arg)) if self.op is Ops.GEP: return self.src[0]._min_max # TODO: CAST to bool/unsigned is not monotone, still some case can be simplified - if self.op is Ops.CAST and self.dtype in (dtypes.floats+dtypes.sints): + if self.op is Ops.CAST and self.dtype in dtypes.floats+dtypes.sints+(dtypes.index,): return max(dtypes.min(self.dtype), self.src[0].vmin), min(self.src[0].vmax, dtypes.max(self.dtype)) return dtypes.min(self.dtype), dtypes.max(self.dtype) @@ -1025,7 +1025,30 @@ def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, na for k,v in input_map.items(): new_map[k] = new_map.get(v,v) return new_map -def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x +def sint_to_uop(x:sint) -> UOp: return UOp.const(dtypes.index, x) if isinstance(x, int) else x.cast(dtypes.index) + +def select_dtype(u): return (dtypes.long if u.overflows(dtypes.int32) else dtypes.int).vec(u.dtype.count) +pm_lower_index_dtype = PatternMatcher([ + # There are no Unary ops at this point in symbolic, those are introduced later + (UPat(GroupOp.Binary, dtypes.index, name="u", src=(UPat.var("x"), UPat.var("y"))), lambda u,x,y: + x.cast(dt:=least_upper_dtype(select_dtype(u), x.dtype, y.dtype)).alu(u.op, y.cast(dt))), + # comparison ops might now have different dtypes in their sources + (UPat(GroupOp.Comparison, name="u", src=(UPat.var("x",dtypes.ints), UPat.var("y", dtypes.ints))), lambda u,x,y: + x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)).alu(u.op, y.cast(dt)) if x.dtype!=y.dtype else None), + (UPat(Ops.WHERE, dtype=dtypes.index, src=(UPat.var("cond"), UPat.var("x"), UPat.var("y")), name="u"), lambda cond,u,x,y: + cond.where(x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)), y.cast(dt))), + (UPat((Ops.CONST, Ops.VCONST), dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=select_dtype(u))), + (UPat((Ops.RANGE,), dtype=dtypes.index, src=(UPat.var("end")), name="r"), lambda ctx,r,end: + r.replace(dtype=(dt:=select_dtype(r)), src=(end.cast(dt),))), + (UPat(Ops.CAST, dtype=dtypes.index, src=(UPat.var("x", dtypes.ints),), name="u"), lambda u,x: x), + (UPat(Ops.VECTORIZE, dtype=dtypes.index, name="u"), lambda u: u.replace( + dtype=(dt:=least_upper_dtype(*[x.dtype for x in u.src])).vec(u.dtype.count), src=tuple(x.cast(dt) for x in u.src))), + (UPat(Ops.VECTORIZE, dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=(dt:=(dtypes.long if any(v.overflows(dtypes.int) for v in u.src) + else dtypes.long)).vec(u.dtype.count),src=tuple(x.cast(dt) for x in u.src))), + (UPat((Ops.SPECIAL,Ops.DEFINE_VAR), dtypes.index, name="u"), lambda u: u.replace(dtype=dtypes.int)), + (UPat((Ops.BIND), dtypes.index, name="u"), lambda u: u.replace(dtype=u.src[0].dtype)), +]) +def index_to_concrete_int(u:UOp): return graph_rewrite(u, pm_lower_index_dtype) _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))]) diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index c119949571..b3066ad5e1 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -28,17 +28,17 @@ try: # float loads only become a variable when they get cast to int/bool (UPat(Ops.LOAD, dtypes.ints, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0])))), - (UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,), name="x"), + (UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,dtypes.index), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx)))), # z3 can cast from bool to int automatically - (UPat(Ops.CAST, dtype=dtypes.ints, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]), + (UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]), (UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], x.src[0].arg[1]!=0))), # if the source of the cast is not a noop it means that it is a float and so we create a new variable - (UPat(Ops.CAST, dtype=dtypes.ints, name="x"), lambda x,ctx: + (UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], create_bounded(f"cast{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0])))), (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], z3.Bool(f"cast{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx)))), - (UPat(Ops.XOR, src=UPat(Ops.NOOP), name="x"), + (UPat(Ops.XOR, dtype=dtypes.ints+(dtypes.bool, ), src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], z3.BV2Int(z3_alu[x.op](*(z3.Int2BV(s.arg[1], x.dtype.itemsize*8) for s in x.src)))))), (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], z3_alu[x.op](*(s.arg[1] for s in x.src))))), # A comparison between floats introduces a new bool variable @@ -95,7 +95,7 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([ (UPat(Ops.VIEW, src=(UPat.var("x"),)), lambda x: x.base.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.CONST, Ops.DEVICE}), # Tensor variable bindings - (UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True), + (UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.index,))), arg=None), lambda: True), # Tensor const has a device and an unmasked ShapeTracker of stride 0 # NOTE: variables in shape can cause multiple views in this ShapeTracker and other issues, see TestSymbolicJit.test_ones_sum @@ -120,6 +120,7 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([ # ***** uop type spec ***** def validate_index(idx:UOp, gate:UOp=UOp.const(dtypes.bool, True)): + # TODO: check for overflow if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := idx.src[0].ptrdtype.size) == -1: return True # We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask if 0<=idx.src[1].vmin and idx.src[1].vmax (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat())), lambda: True), diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 2a9b825477..40e07e067e 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -26,7 +26,7 @@ symbolic_simple = PatternMatcher([ # ** self folding ** (UPat.var("x") + 0, lambda x: x), # x+0 -> x (UPat.var("x") * 1, lambda x: x), # x*1 -> x - (UPat.var("x", dtype=dtypes.ints) ^ 0, lambda x: x), # x^0 -> x + (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) ^ 0, lambda x: x), # x^0 -> x (UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1 (UPat.var("x") // 1, lambda x: x), # x//1 -> x (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x @@ -49,11 +49,11 @@ symbolic_simple = PatternMatcher([ (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x), (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x), (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, False), UPat.const(dtypes.bool, True)), lambda x: x.logical_not()), - (UPat.var("x", dtype=dtypes.ints+(dtypes.bool,)).trunc(), lambda x: x), + (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)).trunc(), lambda x: x), # ** zero folding ** (UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False (UPat.var("x") % UPat.var("x"), lambda x: x.const_like(0)), # x%x -> 0 - (UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints), + (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) != UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints) # x*0 -> 0 or 0*x -> 0 # if x is nan or inf it should render the nan value. @@ -108,6 +108,7 @@ def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None: if fac!=1: if u.op is not Ops.MUL or u.src[1].op is not Ops.CONST or u.src[1].arg != fac: return None u = u.src[0] + if u.op is Ops.CAST and u.src[0].dtype == dtypes.index: u = u.src[0] if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None if denominator != u.src[1].arg: return None if (s0:=u.src[0]).vmin < 0: return None @@ -123,7 +124,7 @@ def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None: for i in range(denominator-len(seen_const)): if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i) if sorted(seen_const)==list(range(denominator)): - return fac*ans + return (fac*ans).cast(divs.dtype) return None def lt_folding(x:UOp, c:int) -> UOp|None: @@ -270,10 +271,19 @@ gep_pushing = PatternMatcher([ (UPat(Ops.WMMA, name="wmma").f(Ops.GEP, name="gep"), gep_through_wmma), ]) +cast_folding = PatternMatcher([ + (UPat.var('x', dtypes.ints+(dtypes.index,)).cast(dtypes.ints+(dtypes.index,), name="a").cast(name="b"), + lambda x,a,b: x.cast(b.dtype) if a.dtype.min<=x.vmin and x.vmax<=a.dtype.max else None), + # try to do math in int instead of long + (UPat(GroupOp.Binary, src=(UPat.var("x", dtypes.long), UPat.var("y", dtypes.long)), name="u"), lambda u,x,y: + x.cast(dtypes.int).alu(u.op, y.cast(dtypes.int)).cast(u.dtype) if not any(v.overflows(dtypes.int) for v in (u,x,y)) else None), + ((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)), +]) + commutative = PatternMatcher([ - # ** COMMUTATIVE flipping (only for ints) ** + # ** COMMUTATIVE flipping (only for index) ** # NOTE: this can break merging vector math by only flipping some of them - (UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None), + (UPat(GroupOp.Commutative, dtype=dtypes.index, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None), ]) symbolic = symbolic_simple+commutative+PatternMatcher([ @@ -289,10 +299,6 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ ((UPat.var("y") + UPat.var("x")) + UPat.var("x"), lambda y,x: y+x*2), ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3) if x2 is not x3 else None), # (x/x2)/x3 -> x/(x2*x3) (-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c - (UPat.var('x', dtypes.ints).cast(dtypes.ints, name="a").cast(name="b"), - lambda x,a,b: x.cast(b.dtype) if a.dtype.min<=x.vmin and x.vmax<=a.dtype.max else None), - (UPat(GroupOp.Binary, src=(UPat.var("x",dtypes.long), UPat.var("y", dtypes.long)), name="u"), lambda u,x,y: - x.cast(dtypes.int).alu(u.op, y.cast(dtypes.int)).cast(u.dtype) if not any(v.overflows(dtypes.int) for v in (u,x,y)) else None), # a conditional with the same results either way is a noop, also fold const conditionals (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val), (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1), @@ -317,27 +323,27 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ ((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2) # ** lt ** # c0*x 0 and c1.arg > 0 else None), # c0*x 0 else x<(c.arg*d.arg-(d.arg-1))) if d.arg > 0 else None), # ** move add/mul consts to end (NOTE: this is still happening before constant folding) ** ((UPat.var("x") + UPat.cvar("c1")) + UPat.var("y"), lambda x,c1,y: (x+y)+c1), ((UPat.var("x") * UPat.cvar("c1")) * UPat.var("y"), lambda x,c1,y: (x*y)*c1), # *** rules from symbolic *** # unrolled arange div folding - ((UPat() + UPat()//UPat.cvar("d", vec=False)).named("divs"), lambda divs,d: fold_unrolled_divs(divs, d.arg)), - ((UPat() + (UPat()//UPat.cvar("d", vec=False))*UPat.cvar("c")).named("divs"), lambda divs,d,c: fold_unrolled_divs(divs, d.arg, c.arg)), + ((UPat()+(UPat()//UPat.cvar("d", vec=False)).or_casted()).named("divs"), lambda divs,d: fold_unrolled_divs(divs, d.arg)), + ((UPat()+((UPat()//UPat.cvar("d", vec=False)).or_casted()*UPat.cvar("c"))).named("divs"), lambda divs,d,c: fold_unrolled_divs(divs, d.arg, c.arg)), # generic lt folding - (UPat.var("x", dtypes.sints) 0 # not x < 1 -> X > 0 - ((UPat.var("x", dtypes.ints)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None), + ((UPat.var("x", dtypes.index)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None), # ** div ** # div folding ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d) @@ -345,30 +351,28 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ # a range mod its own upper bound is just the range (UPat(Ops.RANGE, src=UPat.var("end"), name="r")%UPat.var("end"), lambda r,end: r), (UPat(Ops.RANGE, src=UPat.var("end"), name="r")//UPat.var("end"), lambda r,end: r.const_like(0)), - (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.var("y"))), cancel_divmod), - (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_binary_numerator), - (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_divmod_congruence), - (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), divide_by_gcd), - (UPat(Ops.MOD, dtypes.sints, name="m", src=(UPat.var("x"), UPat.cvar("y", vec=False))), remove_nested_mod), - (UPat((Ops.IDIV), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), nest_div_by_smallest_factor), - (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), simplify_remainder), + (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.var("y"))), cancel_divmod), + (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_binary_numerator), + (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_divmod_congruence), + (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), divide_by_gcd), + (UPat(Ops.MOD, dtypes.index, name="m", src=(UPat.var("x"), UPat.cvar("y", vec=False))), remove_nested_mod), + (UPat((Ops.IDIV), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), nest_div_by_smallest_factor), + (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), simplify_remainder), (UPat.var("x") // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax < 0 else None), (UPat.var("x") // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax <=0 else None), - ((UPat.var("x", dtypes.sints)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False), + ((UPat.var("x", dtypes.index)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False), lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None), # ** mod ** # mod folding (UPat.var("x") % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None), (UPat.var("x") % UPat.var("d"), lambda x,d: (x%(-d)) if d.vmax < 0 else None), - # up + x//c*c + x%c - (UPat.var("up") + UPat.var("x", dtypes.ints)//UPat.cvar("c")*UPat.cvar("c") + UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda up,x,c: up+x), -])+gep_pushing +])+gep_pushing+cast_folding symbolic_flat = symbolic+PatternMatcher([ # ** combine terms (opinionated) ** (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue - ((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c), + ((UPat.var("x", dtypes.index) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c), ]) # ******** we take a small aside to "simplify_valid" to rewrite valids ******** @@ -400,6 +404,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None: # simplify uop given that valid is True for expr,v in bounds.items(): v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1]) + expr = expr.substitute(load_subs) # make sure expr appears in same form in the uop # some expr has lower bound > upper bound -> valid is an empty set and we return None if v0 > v1: return None # whole node became a const