diff --git a/extra/gemm/cdna_asm_gemm.py b/extra/gemm/cdna_asm_gemm.py
index 6b9f373154..fe69e61c82 100644
--- a/extra/gemm/cdna_asm_gemm.py
+++ b/extra/gemm/cdna_asm_gemm.py
@@ -2666,10 +2666,10 @@ def custom_uop_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
   m = UOp.range(M, 1, AxisType.LOOP)
   n = UOp.range(N, 2, AxisType.LOOP)
   k = UOp.range(K, 0, AxisType.REDUCE)
-  mul = (A.flatten().index((m*UOp.const(dtypes.index, K)+k))*
-         B.flatten().index((k*UOp.const(dtypes.index, N)+n))).cast(dtypes.float32)
+  mul = (A.flatten().index((m*UOp.const(dtypes.weakint, K)+k))*
+         B.flatten().index((k*UOp.const(dtypes.weakint, N)+n))).cast(dtypes.float32)
   red = mul.reduce(k, arg=Ops.ADD, dtype=dtypes.float32).cast(C.dtype.base)
-  store = C.flatten().index((m*UOp.const(dtypes.index, N)+n), ptr=True).store(red).end(m, n)
+  store = C.flatten().index((m*UOp.const(dtypes.weakint, N)+n), ptr=True).store(red).end(m, n)
   return store.sink(arg=KernelInfo(name=f'uop_gemm_{M}_{N}_{K}'))
 
 # ** backward gemm, might use the asm gemm
diff --git a/test/backend/test_linearizer_dumb.py b/test/backend/test_linearizer_dumb.py
index 426060a577..b56fa7bbb6 100644
--- a/test/backend/test_linearizer_dumb.py
+++ b/test/backend/test_linearizer_dumb.py
@@ -12,18 +12,18 @@ class TestLinearizerFailure(unittest.TestCase):
 
   @unittest.skipUnless(Device.DEFAULT == "METAL", "only tested on METAL")
   def test_failure_beam_mnist(self):
     c0 = UOp(Ops.PARAM, dtypes.uchar.ptr(4014080), arg=0, src=())
-    c1 = UOp.range(UOp.const(dtypes.index, 512), 0, AxisType.GLOBAL)
-    c2 = UOp.range(UOp.const(dtypes.index, 784), 1, AxisType.GLOBAL)
-    c3 = UOp.range(UOp.const(dtypes.index, 10), 3, AxisType.GLOBAL)
+    c1 = UOp.range(UOp.const(dtypes.weakint, 512), 0, AxisType.GLOBAL)
+    c2 = UOp.range(UOp.const(dtypes.weakint, 784), 1, AxisType.GLOBAL)
+    c3 = UOp.range(UOp.const(dtypes.weakint, 10), 3, AxisType.GLOBAL)
     c4 = UOp(Ops.PARAM, dtypes.int.ptr(512), arg=1, src=())
     c5 = c4.index(c1.valid(UOp.const(dtypes.bool, True)))
-    c6 = UOp.range(UOp.const(dtypes.index, 6000), 1004, AxisType.REDUCE)
-    c7 = UOp.range(UOp.const(dtypes.index, 3750), 2006, AxisType.REDUCE)
-    c8 = UOp.range(UOp.const(dtypes.index, 16), 2007, AxisType.GROUP_REDUCE)
+    c6 = UOp.range(UOp.const(dtypes.weakint, 6000), 1004, AxisType.REDUCE)
+    c7 = UOp.range(UOp.const(dtypes.weakint, 3750), 2006, AxisType.REDUCE)
+    c8 = UOp.range(UOp.const(dtypes.weakint, 16), 2007, AxisType.GROUP_REDUCE)
     c9 = UOp(Ops.PARAM, dtypes.uchar.ptr(47040000), arg=2, src=())
-    c10 = c9.index((((c3*UOp.const(dtypes.index, 4704000))+c2)+(c6*UOp.const(dtypes.index, 784))).valid(UOp.const(dtypes.bool, True)))
-    c11 = c5.alu(Ops.CMPNE, ((((c3*UOp.const(dtypes.index, 6000))+c6)+((c7*UOp.const(dtypes.index, 16))+c8)).alu(Ops.CMPLT, UOp.const(dtypes.index, 59999)).where(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)).reduce(c7, c8, arg=Ops.ADD)+UOp.const(dtypes.int, -1))).where(UOp.const(dtypes.uchar, 0), c10).reduce(c6, arg=Ops.ADD)
-    c12 = c0.index((((c1*UOp.const(dtypes.index, 7840))+(c2*UOp.const(dtypes.index, 10)))+c3).valid(UOp.const(dtypes.bool, True))).store(c11).end(c1, c2, c3)
+    c10 = c9.index((((c3*UOp.const(dtypes.weakint, 4704000))+c2)+(c6*UOp.const(dtypes.weakint, 784))).valid(UOp.const(dtypes.bool, True)))
+    c11 = c5.alu(Ops.CMPNE, ((((c3*UOp.const(dtypes.weakint, 6000))+c6)+((c7*UOp.const(dtypes.weakint, 16))+c8)).alu(Ops.CMPLT, UOp.const(dtypes.weakint, 59999)).where(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)).reduce(c7, c8, arg=Ops.ADD)+UOp.const(dtypes.int, -1))).where(UOp.const(dtypes.uchar, 0), c10).reduce(c6, arg=Ops.ADD)
+    c12 = c0.index((((c1*UOp.const(dtypes.weakint, 7840))+(c2*UOp.const(dtypes.weakint, 10)))+c3).valid(UOp.const(dtypes.bool, True))).store(c11).end(c1, c2, c3)
     ast = c12.sink(arg=KernelInfo(name='test', axis_types=(), dont_use_locals=False, applied_opts=(Opt(op=OptOps.GROUP, axis=1, arg=16),), opts_to_apply=None))
     _ = get_program(ast, Device["METAL"].renderer)
diff --git a/test/external/external_benchmark_op_conv.py b/test/external/external_benchmark_op_conv.py
index cc8c3c00f8..144cda7661 100644
--- a/test/external/external_benchmark_op_conv.py
+++ b/test/external/external_benchmark_op_conv.py
@@ -23,7 +23,7 @@ def vision_conv_143():
   c32 = ((c27<3)!=True)&(c27<67)
   c34 = UOp(Ops.PARAM, dtypes.imageh((32, 1024, 4)), (), 1)
   c38 = c5//2
-  c45 = (c32&c24).where((c27*64+c38+c17*4096+-12480), UOp.const(dtypes.index, Invalid))
+  c45 = (c32&c24).where((c27*64+c38+c17*4096+-12480), UOp.const(dtypes.weakint, Invalid))
   c48 = (c24&c32).where(c34.index(c45), UOp.const(dtypes.float, 0.0))
   c49 = UOp(Ops.PARAM, dtypes.imageh((64, 49, 4)), (), 2)
   c61 = c48*c49.index((c26*4+c5%2+c16*28+c38*196))
@@ -49,7 +49,7 @@ def vision_conv_153():
   c32 = ((c27<3)!=True)&(c27<35)
   c34 = UOp(Ops.PARAM, dtypes.imageh((16, 1024, 4)), (), 1)
   c38 = c5//2
-  c45 = (c32&c24).where((c27*128+c38+c17*4096+-12672), UOp.const(dtypes.index, Invalid))
+  c45 = (c32&c24).where((c27*128+c38+c17*4096+-12672), UOp.const(dtypes.weakint, Invalid))
   c48 = (c24&c32).where(c34.index(c45), UOp.const(dtypes.float, 0.0))
   c49 = UOp(Ops.PARAM, dtypes.imageh((128, 49, 4)), (), 2)
   c61 = c48*c49.index((c26*4+c5%2+c16*28+c38*196))
diff --git a/test/null/test_gpudims.py b/test/null/test_gpudims.py
index a3e9956a7f..53b0856612 100644
--- a/test/null/test_gpudims.py
+++ b/test/null/test_gpudims.py
@@ -24,7 +24,7 @@ class TestGroupedDims(unittest.TestCase):
     total = math.prod(dims)
     specials = sorted(dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs])), key=lambda u: u.arg)
     # build flat index and primed flat (same expression with renamed SPECIALs)
-    flat = UOp.const(dtypes.index, 0)
+    flat = UOp.const(dtypes.weakint, 0)
     for i, idx in enumerate(idxs): flat = flat + idx * int(math.prod(dims[i+1:]))
     flat_p = flat.substitute({s: UOp(Ops.SPECIAL, s.dtype, s.src, s.arg+"_p") for s in specials})
diff --git a/test/null/test_graph_rewrite.py b/test/null/test_graph_rewrite.py
index c25dfc342f..24b5315c5b 100644
--- a/test/null/test_graph_rewrite.py
+++ b/test/null/test_graph_rewrite.py
@@ -99,21 +99,21 @@ class TestFoldingAndReduction(unittest.TestCase):
 class TestModuloAndDivisionFolding(unittest.TestCase):
   def test_full_graph_rewrite_modulo_folding_with_define_var(self):
     # index dtype because div-mod rules only work on index
-    x_var_uop = UOp.variable('x', 0, 100).cast(dtypes.index)
+    x_var_uop = UOp.variable('x', 0, 100).cast(dtypes.weakint)
     optimized_mod_uop = apply_rewrite(((x_var_uop * 4) + 2) % 4)
     self.assertEqual(optimized_mod_uop.op, Ops.CONST)
     self.assertEqual(optimized_mod_uop.arg, 2)
 
   def test_full_graph_rewrite_division_folding_with_define_var(self):
     # index dtype because div-mod rules only work on index
-    n_var_uop = UOp.variable('n', 1, 1000).cast(dtypes.index)
+    n_var_uop = UOp.variable('n', 1, 1000).cast(dtypes.weakint)
     optimized_div_uop = apply_rewrite((n_var_uop * 6) // 3)
     self.assertEqual(optimized_div_uop.op, Ops.MUL)
     self.assertEqual(optimized_div_uop.src[1].arg, 2)
 
   def test_full_graph_rewrite_complex_mod_div_folding(self):
     # index dtype because div-mod rules only work on index
-    k_var_uop = UOp.variable('k', 0, 50).cast(dtypes.index)
+    k_var_uop = UOp.variable('k', 0, 50).cast(dtypes.weakint)
     optimized_div_uop = apply_rewrite(((k_var_uop * 12 + 8) % 6) // 2)
     self.assertEqual(optimized_div_uop.op, Ops.CONST)
     self.assertEqual(optimized_div_uop.arg, 1)
@@ -132,7 +132,7 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
   def test_full_graph_rewrite_modulo_large_divisor(self):
     # index dtype because div-mod rules only work on index
     x_var_uop = UOp.variable('x', 1, 5)
-    self.assertIs(apply_rewrite(x_var_uop.cast(dtypes.index) % 10).render(simplify=False), x_var_uop.render(simplify=False))
+    self.assertIs(apply_rewrite(x_var_uop.cast(dtypes.weakint) % 10).render(simplify=False), x_var_uop.render(simplify=False))
 
   def test_full_graph_rewrite_division_with_remainder(self):
     x_var_uop = UOp.variable('x', 7, 9)
diff --git a/test/null/test_linearizer_failures.py b/test/null/test_linearizer_failures.py
index 955a8414dc..b8636c8ac2 100644
--- a/test/null/test_linearizer_failures.py
+++ b/test/null/test_linearizer_failures.py
@@ -8,12 +8,12 @@ from tinygrad.device import Device
 class TestLinearizerFailures(unittest.TestCase):
   def test_fail_1(self):
     c0 = UOp(Ops.PARAM, dtypes.float.ptr(64), arg=0, src=())
-    c1 = UOp.range(UOp.const(dtypes.index, 2), 1, AxisType.LOOP)
-    c2 = UOp.range(UOp.const(dtypes.index, 32), 2, AxisType.LOOP)
-    c3 = ((c1*UOp.const(dtypes.index, 32))+c2)
+    c1 = UOp.range(UOp.const(dtypes.weakint, 2), 1, AxisType.LOOP)
+    c2 = UOp.range(UOp.const(dtypes.weakint, 32), 2, AxisType.LOOP)
+    c3 = ((c1*UOp.const(dtypes.weakint, 32))+c2)
     c4 = UOp(Ops.PARAM, dtypes.float.ptr(163840), arg=1, src=())
-    c5 = UOp.range(UOp.const(dtypes.index, 2560), 0, AxisType.REDUCE)
-    c6 = c4.index(((((((c5//UOp.const(dtypes.index, 8))%UOp.const(dtypes.index, 8))*UOp.const(dtypes.index, 8))+(c5%UOp.const(dtypes.index, 8)))+(((c2*UOp.const(dtypes.index, 40))+(c5//UOp.const(dtypes.index, 64)))*UOp.const(dtypes.index, 64)))+(c1*UOp.const(dtypes.index, 81920))))
+    c5 = UOp.range(UOp.const(dtypes.weakint, 2560), 0, AxisType.REDUCE)
+    c6 = c4.index(((((((c5//UOp.const(dtypes.weakint, 8))%UOp.const(dtypes.weakint, 8))*UOp.const(dtypes.weakint, 8))+(c5%UOp.const(dtypes.weakint, 8)))+(((c2*UOp.const(dtypes.weakint, 40))+(c5//UOp.const(dtypes.weakint, 64)))*UOp.const(dtypes.weakint, 64)))+(c1*UOp.const(dtypes.weakint, 81920))))
     c7 = UOp(Ops.PARAM, dtypes.float.ptr(64), arg=2, src=())
     c8 = c7.index(c3)
     c9 = ((((c6+(c8*UOp.const(dtypes.float, -1.0)))*(c6+(c8*UOp.const(dtypes.float, -1.0)))).reduce(c5, arg=Ops.ADD)*UOp.const(dtypes.float, 0.000390625))+UOp.const(dtypes.float, 1e-05)).sqrt().reciprocal()
diff --git a/test/null/test_simplify_valid_idx.py b/test/null/test_simplify_valid_idx.py
index eaaf297228..71b5a35817 100644
--- a/test/null/test_simplify_valid_idx.py
+++ b/test/null/test_simplify_valid_idx.py
@@ -15,11 +15,11 @@ def get_gated_load_uop(valid:UOp, idx:UOp):
 
 def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
   return UOp(Ops.LOAD, dtypes.float.vec(4), (
-    UOp(Ops.PARAM, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.VECTORIZE, dtypes.index.vec(2), idx).valid(valid), ptr=True),
+    UOp(Ops.PARAM, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.VECTORIZE, dtypes.weakint.vec(2), idx).valid(valid), ptr=True),
     UOp(Ops.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
   ))
 
-def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, nmax),), expr)
+def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.weakint, (UOp.const(dtypes.weakint, nmax),), expr)
 def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
 def Range(n, nmax): return UOp.range(nmax, n)
@@ -478,7 +478,7 @@ class TestDropTrueGate(unittest.TestCase):
     from tinygrad.codegen.late.devectorizer import load_store_indexing
     from tinygrad.uop.ops import graph_rewrite
     buf = UOp(Ops.PARAM, dtypes.int.ptr(), arg=0)
-    idx = UOp.const(dtypes.index, 0)
+    idx = UOp.const(dtypes.weakint, 0)
     true_gate = UOp.const(dtypes.bool, True)
     index_with_gate = UOp(Ops.INDEX, dtypes.int.ptr(), (buf, idx, true_gate))
     # apply the optimization
@@ -495,7 +495,7 @@ class TestRangeShrink(unittest.TestCase):
   def test_range_shrink_single_guard(self):
     # range 0..203 guarded by r < 4 everywhere -> shrink to 0..3
     r = Range(0, 204)
-    load = get_gated_load_uop(r < UOp.const(dtypes.index, 4), r)
+    load = get_gated_load_uop(r < UOp.const(dtypes.weakint, 4), r)
     ranges = self.get_ranges(load.sink())
     self.assertEqual(len(ranges), 1)
     self.assertEqual(ranges[0].src[0].arg, 4)
@@ -503,8 +503,8 @@ class TestRangeShrink(unittest.TestCase):
   def test_range_shrink_picks_max_guard(self):
     # two loads guard the same range with r < 4 and r < 8 -> shrink to max(4, 8) = 8
     r = Range(0, 204)
-    load1 = get_gated_load_uop(r < UOp.const(dtypes.index, 4), r)
-    load2 = get_gated_load_uop(r < UOp.const(dtypes.index, 8), r)
+    load1 = get_gated_load_uop(r < UOp.const(dtypes.weakint, 4), r)
+    load2 = get_gated_load_uop(r < UOp.const(dtypes.weakint, 8), r)
     ranges = self.get_ranges(UOp.sink(load1, load2))
     self.assertEqual(len(ranges), 1)
     self.assertEqual(ranges[0].src[0].arg, 8)
@@ -512,7 +512,7 @@ class TestRangeShrink(unittest.TestCase):
   def test_range_no_shrink_guard_ge_max(self):
     # guard r < 300 with range max 204 -> no shrink (guard doesn't constrain)
     r = Range(0, 204)
-    load = get_gated_load_uop(r < UOp.const(dtypes.index, 300), r)
+    load = get_gated_load_uop(r < UOp.const(dtypes.weakint, 300), r)
     ranges = self.get_ranges(load.sink())
     self.assertEqual(len(ranges), 1)
     self.assertEqual(ranges[0].src[0].arg, 204)
@@ -520,7 +520,7 @@ class TestRangeShrink(unittest.TestCase):
   def test_range_no_shrink_when_unguarded_elsewhere(self):
     # one load guards r < 4, but another load uses r without a gate -> no shrink
     r = Range(0, 204)
-    load1 = get_gated_load_uop(r < UOp.const(dtypes.index, 4), r)
+    load1 = get_gated_load_uop(r < UOp.const(dtypes.weakint, 4), r)
     load2 = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.PARAM, dtypes.float.ptr(), arg=1).index(r, ptr=True),))
     ranges = self.get_ranges(UOp.sink(load1, load2))
     self.assertEqual(len(ranges), 1)
@@ -529,7 +529,7 @@ class TestRangeShrink(unittest.TestCase):
   def test_range_no_shrink_when_used_in_reduce(self):
     # range used in both a gated load AND directly in the reduce expression -> no shrink
     r = Range(0, 204)
-    gated_load = get_gated_load_uop(r < UOp.const(dtypes.index, 4), r)
+    gated_load = get_gated_load_uop(r < UOp.const(dtypes.weakint, 4), r)
     red = UOp(Ops.REDUCE, dtypes.float, (r.cast(dtypes.float) + gated_load, r), Ops.ADD)
     ranges = self.get_ranges(red.sink())
     self.assertEqual(len(ranges), 1)
@@ -538,7 +538,7 @@ class TestRangeShrink(unittest.TestCase):
   def test_range_shrink_to_single_iteration(self):
     # guard r < 1 shrinks range to 1 -> single iteration, range eliminated entirely
     r = Range(0, 204)
-    load = get_gated_load_uop(r < UOp.const(dtypes.index, 1), r)
+    load = get_gated_load_uop(r < UOp.const(dtypes.weakint, 1), r)
     ranges = self.get_ranges(load.sink())
     self.assertEqual(len(ranges), 0)
diff --git a/test/null/test_uop_graph.py b/test/null/test_uop_graph.py
index 95d7742e9f..b6aa0116f0 100644
--- a/test/null/test_uop_graph.py
+++ b/test/null/test_uop_graph.py
@@ -205,7 +205,7 @@ class TestUOpGraph(unittest.TestCase):
 
   def test_where_same_fold(self):
     v = UOp.variable('tmp', 0, 1)
-    c0 = UOp.const(dtypes.index, 0)
+    c0 = UOp.const(dtypes.weakint, 0)
     vc = UOp(Ops.CMPNE, dtypes.bool, (v, c0))
     c1 = UOp.const(dtypes.float, 1.0)
     out = UOp(Ops.WHERE, dtypes.float, (vc, c1, c1))
@@ -469,16 +469,16 @@ class TestUOpGraph(unittest.TestCase):
     # mnist indexing with split reduceop
     # Make sure we are not doign math on the loaded index, which would promote it to long
     c0 = UOp(Ops.PARAM, dtypes.uchar.ptr(128000), arg=0, src=())
-    c1 = UOp.range(UOp.const(dtypes.index, 512), 1, AxisType.LOOP)
-    c2 = UOp.range(UOp.const(dtypes.index, 250), 2, AxisType.LOOP)
+    c1 = UOp.range(UOp.const(dtypes.weakint, 512), 1, AxisType.LOOP)
+    c2 = UOp.range(UOp.const(dtypes.weakint, 250), 2, AxisType.LOOP)
     c3 = UOp(Ops.PARAM, dtypes.int.ptr(512), arg=1, src=())
     c4 = c3.index(c1)
-    c5 = UOp.range(UOp.const(dtypes.index, 240), 0, AxisType.REDUCE)
-    c6 = ((c2*UOp.const(dtypes.index, 240))+c5)
+    c5 = UOp.range(UOp.const(dtypes.weakint, 240), 0, AxisType.REDUCE)
+    c6 = ((c2*UOp.const(dtypes.weakint, 240))+c5)
     c7 = UOp(Ops.PARAM, dtypes.uchar.ptr(60000), arg=2, src=())
     c8 = c7.index(c6)
     c9 = ((c4<0).where((c4+60000), c4)!=c6.cast(dtypes.int)).where(0, c8.cast(dtypes.uint).cast(dtypes.uchar)).reduce(c5, arg=Ops.ADD)
-    c10 = c0.index(((c1*UOp.const(dtypes.index, 250))+c2)).store(c9).end(c1, c2)
+    c10 = c0.index(((c1*UOp.const(dtypes.weakint, 250))+c2)).store(c9).end(c1, c2)
     uops = to_uops_list([c10])
     for u in uops:
       self.assertNotEqual(u.dtype, dtypes.long)
@@ -486,19 +486,19 @@ class TestUOpGraph(unittest.TestCase):
   def test_load_idx_no_math_on_loaded(self):
     # test the (x+y) NOOP rule. This rule matches patterns that EMERGE during simplification."""
   def test_store_load_folding(self):
     # store(idx, load(idx)) -> NOOP, including emergent patterns like store(idx, load(idx) + 0)
     buf = UOp(Ops.PARAM, dtypes.int.ptr(), arg=0)
-    index = buf.index(UOp.const(dtypes.index, 0))
+    index = buf.index(UOp.const(dtypes.weakint, 0))
     # Direct: store(idx, load(idx)) -> NOOP
     self.assertEqual(graph_rewrite(index.store(index.load()), sym).op, Ops.NOOP)
     # Emergent: store(idx, load(idx) + 0) -> store(idx, load(idx)) -> NOOP
@@ -1250,10 +1250,10 @@ class TestGatedUopGivenValid(unittest.TestCase):
     idx0 = (r0 + uconst(-1)) // uconst(3)
     idx1 = r0 % uconst(3)
-    idx:UOp = (r0 < 3).where(UOp(Ops.VECTORIZE, dtypes.index.vec(2), (idx0, idx1)), UOp.invalid())
+    idx:UOp = (r0 < 3).where(UOp(Ops.VECTORIZE, dtypes.weakint.vec(2), (idx0, idx1)), UOp.invalid())
     idx = graph_rewrite(idx, pm_simplify_valid)
     # NOTE: independent simplification: (r0-1)//3 -> 0, r0%3 -> r0 when r0 in [0,2]
-    expected_vec = UOp(Ops.VECTORIZE, dtypes.index.vec(2), (uconst(0), r0))
+    expected_vec = UOp(Ops.VECTORIZE, dtypes.weakint.vec(2), (uconst(0), r0))
     self.assertEqual(idx, (r0 < 3).where(expected_vec, UOp.invalid()))
 
 class TestRangeSplitting(unittest.TestCase):
diff --git a/test/null/test_uop_vmin_vmax.py b/test/null/test_uop_vmin_vmax.py
index 9b307a813b..5b27b76dca 100644
--- a/test/null/test_uop_vmin_vmax.py
+++ b/test/null/test_uop_vmin_vmax.py
@@ -144,7 +144,7 @@ class TestVminVmaxProperties(unittest.TestCase):
     self.assertNotEqual(i.vmin, i.vmax)
 
   def test_vmin_vmax_invalid_vconst(self):
-    x = UOp.const(dtypes.index.vec(4), (0, 4, Invalid, Invalid))
+    x = UOp.const(dtypes.weakint.vec(4), (0, 4, Invalid, Invalid))
     self.assertLess(x.vmin, 0)
     self.assertGreater(x.vmax, 4)
diff --git a/test/null/test_validate_oob.py b/test/null/test_validate_oob.py
index aea9d2d54f..a694b81cd6 100644
--- a/test/null/test_validate_oob.py
+++ b/test/null/test_validate_oob.py
@@ -126,7 +126,7 @@ class TestValidateOOB(unittest.TestCase):
     buf0 = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
     buf1 = UOp(Ops.PARAM, dtypes.int.ptr(64), (), 1)
     r = UOp.range(42, 0, AxisType.GLOBAL)
-    ld0 = buf0.index(r.valid(r < 8), ptr=True).load(dtype=dtypes.int).cast(dtypes.index)
+    ld0 = buf0.index(r.valid(r < 8), ptr=True).load(dtype=dtypes.int).cast(dtypes.weakint)
     to_uops_list([buf1.index((ld0 * 2).valid((ld0 >= 0) & (ld0 < 32)), ptr=True).load(dtype=dtypes.int)])  # valid
     with self.assertRaises(RuntimeError):
       to_uops_list([buf1.index((ld0 * 2).valid((ld0 >= 0) & (ld0 < 64)), ptr=True).load(dtype=dtypes.int)])  # oob
@@ -135,7 +135,7 @@ class TestValidateOOB(unittest.TestCase):
     with Context(CHECK_OOB=1, SPEC=2):
       buf_bool = UOp(Ops.PARAM, dtypes.bool.ptr(16), (), 0)
       buf_int = UOp(Ops.PARAM, dtypes.int.ptr(8), (), 1)
-      gidx = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 16),), "gidx0")
+      gidx = UOp(Ops.SPECIAL, dtypes.weakint, (UOp.const(dtypes.weakint, 16),), "gidx0")
       ld_bool = buf_bool.index(gidx, ptr=True).load()
       with self.assertRaises(RuntimeError):
         to_uops_list([buf_int.index(gidx.valid(ld_bool), ptr=True).load()])  # gidx 0..15, buf_int size 8
diff --git a/tinygrad/codegen/gpudims.py b/tinygrad/codegen/gpudims.py
index dc37a1d503..e4f0f440d6 100644
--- a/tinygrad/codegen/gpudims.py
+++ b/tinygrad/codegen/gpudims.py
@@ -35,7 +35,7 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
   if len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
   # try to split up dims: (a,) -> (b, c)
   if limited == dims: limited = _split_dims(dims, max_sizes)
-  raw_idxs = [UOp(Ops.SPECIAL, dtypes.index, (sint_to_uop(s),), (f"{prefix}{i}")) for i,s in enumerate(limited)]
+  raw_idxs = [UOp(Ops.SPECIAL, dtypes.weakint, (sint_to_uop(s),), (f"{prefix}{i}")) for i,s in enumerate(limited)]
   if len(limited) < len(dims):
     ret = []
     if (contraction:=get_contraction(dims, limited)) is None: raise RuntimeError(f"get_contraction should not be None {dims=} {limited=}")
@@ -75,7 +75,7 @@ def add_gpudims(ctx:Renderer, s:UOp):
   # get the idxs
   ki: KernelInfo = s.arg
-  if ctx.has_threads: idxs = [UOp.variable("core_id", 0, int(global_shape[0])-1, dtypes.int).cast(dtypes.index)]
+  if ctx.has_threads: idxs = [UOp.variable("core_id", 0, int(global_shape[0])-1, dtypes.int).cast(dtypes.weakint)]
   elif ki.dont_use_locals:
     assert not local_dims, "can't use locals if there's no local dims"
     idxs = get_grouped_dims("idx", global_shape, ctx.global_max, reverse=True)
diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py
index e8bc757f5e..916bbc79aa 100644
--- a/tinygrad/codegen/late/devectorizer.py
+++ b/tinygrad/codegen/late/devectorizer.py
@@ -197,7 +197,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
   return UOp(Ops.VCAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret)
 
 def get_image_idx(idx:UOp, width:int):
-  oidx = UOp(Ops.VECTORIZE, dtypes.index.vec(2), (((x:=idx.src[1].get_idx()) // 4) % width, (x // (4*width))))
+  oidx = UOp(Ops.VECTORIZE, dtypes.weakint.vec(2), (((x:=idx.src[1].get_idx()) // 4) % width, (x // (4*width))))
   return idx.replace(src=(idx.src[0], oidx.valid(idx.src[1].get_valid())))
 
 def image_fixup(ls:UOp):
@@ -208,7 +208,7 @@ def image_fixup(ls:UOp):
     return ls.replace(src=(idx,)+ls.src[1:])
 
   # this is an unprocessed image without a cast, aka unfoldable image load. this doesn't work for stores
-  if isinstance(image_dtype:=ls.src[0].dtype, ImageDType) and ls.src[0].src[1].get_idx().dtype != dtypes.index.vec(2):
+  if isinstance(image_dtype:=ls.src[0].dtype, ImageDType) and ls.src[0].src[1].get_idx().dtype != dtypes.weakint.vec(2):
     assert ls.op is Ops.LOAD, "if an image store isn't upcasted to 4, we can't store it"
     x, idx = ls.src[0].src[1].get_idx(), get_image_idx(ls.src[0], image_dtype.shape[1])
     vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:])
@@ -263,7 +263,7 @@ def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp, bcast:UOp|None=None):
     # simple scalar index: one lane, all components
     pairs = [(0, c) for c in range(cnt)]
   idx_lanes, offsets = (tuple(x) for x in zip(*pairs))
-  return buf.broadcast(len(pairs)).index(idx.gep(idx_lanes)*cnt + UOp.const(dtypes.index.vec(len(pairs)), offsets), ptr=True)
+  return buf.broadcast(len(pairs)).index(idx.gep(idx_lanes)*cnt + UOp.const(dtypes.weakint.vec(len(pairs)), offsets), ptr=True)
 
 devectorize_buf_and_index = PatternMatcher([
   (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
diff --git a/tinygrad/codegen/simplify.py b/tinygrad/codegen/simplify.py
index fd03d39c44..3bd7041624 100644
--- a/tinygrad/codegen/simplify.py
+++ b/tinygrad/codegen/simplify.py
@@ -157,5 +157,5 @@ def no_load(u:UOp) -> bool: return not any(x.op is Ops.INDEX for x in u.backward
 pm_load_collapse = PatternMatcher([
   (UPat(Ops.REDUCE, arg=Ops.ADD, src=(UPat.var("u"), UPat()), name="red"), reduce_load_collapse),
   # we want to make sure we dont do math on a loaded index since that can cause overflow, this undoes the rule in pm_reduce_load_collapse
-  ((UPat.var("x", dtypes.index)+UPat.var("y"))
 bool:
-  if dtype == dtypes.index: return False
+  if dtype == dtypes.weakint: return False
   if device is None: device = Device.DEFAULT
   if dtype == dtypes.bfloat16:
     if device == "METAL": return not CI
diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py
index 5dce3c3e5b..af8b9ae6db 100644
--- a/tinygrad/dtype.py
+++ b/tinygrad/dtype.py
@@ -159,7 +159,7 @@ class dtypes:
   def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats or isinstance(x, ImageDType)
   @staticmethod  # static methods on top, or bool in the type info will refer to dtypes.bool
   @functools.cache
-  def is_int(x: DType) -> bool: return x.scalar() in (dtypes.ints + (dtypes.index,))
+  def is_int(x: DType) -> bool: return x.scalar() in (dtypes.ints + (dtypes.weakint,))
   @staticmethod
   @functools.cache
   def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints
@@ -180,7 +180,7 @@ class dtypes:
     return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52), dtypes.fp8e4m3: (4, 3), dtypes.fp8e5m2: (5, 2), dtypes.fp8e4m3fnuz: (4, 3), dtypes.fp8e5m2fnuz: (5, 2)}[dtype]
   void: Final[DType] = DType.new(-1, 0, "void", None)
-  index: Final[DType] = DType.new(-1, 800, "index", None)
+  weakint: Final[DType] = DType.new(-1, 800, "weakint", None)
   bool: Final[DType] = DType.new(0, 1, "bool", '?')
   int8: Final[DType] = DType.new(1, 8, "signed char", 'b')
   uint8: Final[DType] = DType.new(2, 8, "unsigned char", 'B')
@@ -227,7 +227,7 @@ class dtypes:
   uints = (uint8, uint16, uint32, uint64)
   sints = (int8, int16, int32, int64)
   ints = uints + sints
-  all = floats + ints + (bool, index)  # noqa: A003
+  all = floats + ints + (bool, weakint)  # noqa: A003
 
 if (env_default_float := getenv("DEFAULT_FLOAT", "")): dtypes.default_float = getattr(dtypes, env_default_float.lower())
@@ -237,7 +237,8 @@
 DTypeLike = str|DType
 def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype.lower())
 
 # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
-# we don't support weak type and complex type
+# we don't support complex type
+# TODO: weakint and weakfloat in lattice
 promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64], dtypes.int64: [dtypes.uint64],
   dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32], dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.fp8e4m3, dtypes.fp8e5m2, dtypes.fp8e4m3fnuz, dtypes.fp8e5m2fnuz],
@@ -254,8 +255,8 @@ def least_upper_dtype(*ds:DType) -> DType:
     if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
 def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.default_float)
 
-DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void", "index", "_"))}
-INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void", "index":"index"}
+DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void", "weakint", "_"))}
+INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void", "weakint":"weakint"}
 
 @functools.cache
 def can_lossless_cast(dt0:DType, dt1:DType) -> bool:
@@ -263,7 +264,7 @@ def can_lossless_cast(dt0:DType, dt1:DType) -> bool:
   # similar to https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html
   if dt0 == dt1 or dt0 == dtypes.bool: return True
   match dt1:
-    case dtypes.index: return dt0 in dtypes.ints
+    case dtypes.weakint: return dt0 in dtypes.ints
     case dtypes.double: return dt0 in (dtypes.float, dtypes.half, dtypes.bfloat16, *dtypes.fp8s, dtypes.uint32, dtypes.uint16, dtypes.uint8, dtypes.int32, dtypes.int16, dtypes.int8)
     case dtypes.float: return dt0 in (dtypes.half, dtypes.bfloat16, *dtypes.fp8s, dtypes.uint16, dtypes.uint8, dtypes.int16, dtypes.int8)
diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py
index f7ec524346..7e3c500f0f 100644
--- a/tinygrad/nn/__init__.py
+++ b/tinygrad/nn/__init__.py
@@ -347,12 +347,12 @@ def _embedding_bwd(grad_emb:UOp, call:UOp) -> tuple:
       # each device owns [offset, offset+local_vocab_size) of the global vocabulary
       dnum = UOp.variable("_device_num", 0, ndev-1)
       offset = dnum * local_vocab_size
-      global_token_id = idx_flat[i].cast(dtypes.index)
+      global_token_id = idx_flat[i].cast(dtypes.weakint)
       local_token_id = (global_token_id - offset).clip(0, grad_weight.shape[0]-1)
       in_range = (global_token_id >= offset) & (global_token_id < (offset + local_vocab_size))
       grad_val = in_range.where(grad_emb_flat[i, j].cast(dtypes.float), 0.0)
     else:
-      local_token_id = idx_flat[i].clip(0, grad_weight.shape[0]-1).cast(dtypes.index)
+      local_token_id = idx_flat[i].clip(0, grad_weight.shape[0]-1).cast(dtypes.weakint)
       grad_val = grad_emb_flat[i, j].cast(dtypes.float)
     # atomic scatter-add: grad_weight[token_id, j] += grad_emb_flat[i, j]
     if device in ("CPU", "NULL"): atomic_arg = "__atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED);"
diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py
index 0b46ccf1c5..6cbd8b3920 100644
--- a/tinygrad/schedule/indexing.py
+++ b/tinygrad/schedule/indexing.py
@@ -52,8 +52,8 @@ class IndexingContext:
   range_idx: Iterator[int] = field(default_factory=itertools.count)
   def new_range(self, s:sint, axistype:AxisType=AxisType.LOOP) -> UOp:
     if isinstance(s, UOp) and s.op is Ops.RANGE: return s
-    # if a range has a 1 src, it's the same as UOp.const(dtypes.index, 0)
-    return UOp.range(s, next(self.range_idx), axistype) if resolve(s!=1) else UOp.const(dtypes.index, 0)
+    # if a range has a 1 src, it's the same as UOp.const(dtypes.weakint, 0)
+    return UOp.range(s, next(self.range_idx), axistype) if resolve(s!=1) else UOp.const(dtypes.weakint, 0)
 
 def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
   if x.op in {Ops.BUFFERIZE, Ops.INDEX}: return None
@@ -118,7 +118,7 @@ def _apply_reshape(in_shape:tuple[sint,...], out_shape:tuple[sint, ...], urngs:U
   for s,src in list(zip(out_shape, urngs.src))[::-1]:
     axes_in.append(acc*src)
     acc *= s
-  combined_axes = UOp.const(dtypes.index, 0).sum(*axes_in)
+  combined_axes = UOp.const(dtypes.weakint, 0).sum(*axes_in)
   axes_out:list[UOp] = []
   for s in in_shape[::-1]:
     axes_out.append(combined_axes % s)
@@ -172,7 +172,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
     # treat MSTACK/MSELECT like SINK
     if x.op in {Ops.MSTACK, Ops.MSELECT}: continue
-    if x.dtype.scalar() == dtypes.index: continue  # TODO: why do I need this?
+    if x.dtype.scalar() == dtypes.weakint: continue  # TODO: why do I need this?
     ending_ranges[x] = sum([ending_ranges.get(u, []) for u in consumer_map[x]], [])
 
   # *** the ranges on the output are
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index 46aa5923f9..2370a02110 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -94,7 +94,7 @@ def split_reduceop(reduce:UOp, x:UOp):
   # split is moved to the end to provide maximum locality for the second phase reduce.
 
   # get expanded by rangeifying the UOp x
-  indexed = x.index(*[UOp.range(s, i) if resolve(s>1) else UOp.const(dtypes.index, 0) for i,s in enumerate(x.shape)])
+  indexed = x.index(*[UOp.range(s, i) if resolve(s>1) else UOp.const(dtypes.weakint, 0) for i,s in enumerate(x.shape)])
   range_nums = [y.arg[0] for y in indexed.substitute({x.base:UOp(Ops.NOOP)}, extra_pm=pm_mops).ranges]
   is_expanded = [i not in range_nums for i in range(len(x.shape))]
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 52c6a24627..8157a0bc64 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -132,9 +132,10 @@ class Tensor(OpMixin):
 
     # create a UOp from the different types of inputs
     if isinstance(data, UOp):
-      assert _dtype is None or _dtype==data.dtype or data.dtype==dtypes.index, f"dtype mismatch: {_dtype} vs {data.dtype}"
-      # if data is dtype.index that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of
-      if data.dtype==dtypes.index: data = _index_to_concrete_int(data)
+      assert _dtype is None or _dtype==data.dtype or data.dtype==dtypes.weakint, f"dtype mismatch: {_dtype} vs {data.dtype}"
+      # if data is dtype.weakint that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of
+      # TODO: remove this and stay in weakint
+      if data.dtype==dtypes.weakint: data = _index_to_concrete_int(data)
       if data.op is Ops.BIND:
         var, val = data.unbind()
         # give the bound constant a device
diff --git a/tinygrad/uop/divandmod.py b/tinygrad/uop/divandmod.py
index e5bdcb5d20..086b937b8d 100644
--- a/tinygrad/uop/divandmod.py
+++ b/tinygrad/uop/divandmod.py
@@ -107,17 +107,17 @@ div_and_mod_symbolic = PatternMatcher([
   # ** 1. Fast Inline Rules **
   ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"),
    lambda x,c,a,d: (x+a*c)//(c*d) if c.vmin>0 and d.vmin>0 and x.vmin>=0 and a.vmin>=0 else None),  # (x//c+a)//d -> (x+a*c)//(c*d)
-  (UPat.var("x", dtypes.index) // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax < 0 else None),
-  (UPat.var("x", dtypes.index) // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax <= 0 else None),
-  ((UPat.var("x", dtypes.index)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
+  (UPat.var("x", dtypes.weakint) // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax < 0 else None),
+  (UPat.var("x", dtypes.weakint) // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax <= 0 else None),
+  ((UPat.var("x", dtypes.weakint)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
    lambda x,c,n,d: ((x+c.arg%d.arg)//d + c.arg//d.arg) if c.arg%d.arg!=c.arg and x.vmin>=0 and n.vmin>=0 and d.arg>0 else None),
-  ((UPat.var("x", dtypes.index)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
+  ((UPat.var("x", dtypes.weakint)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
    lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None),
 
   # ** 2. Slow Rules **
-  (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d"), lambda d: fold_divmod_general(d, bool(CORRECT_DIVMOD_FOLDING))),
+  (UPat((Ops.IDIV, Ops.MOD), dtypes.weakint, name="d"), lambda d: fold_divmod_general(d, bool(CORRECT_DIVMOD_FOLDING))),
   # NOTE: these have to go at the bottom or TestSymbolicOps.test_var loops
-  (UPat.var("x", dtypes.index) % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None),
-  (UPat.var("x", dtypes.index) % UPat.var("d"), lambda x,d: (x%(-d)) if d.vmax < 0 else None),
+  (UPat.var("x", dtypes.weakint) % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None),
+  (UPat.var("x", dtypes.weakint) % UPat.var("d"), lambda x,d: (x%(-d)) if d.vmax < 0 else None),
 ])
\ No newline at end of file
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index fe66fc85e9..97848e2ae2 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -59,9 +59,9 @@ def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
   return ret
 
 def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
-  if len(arg) == 0: return UOp(Ops.VECTORIZE, dtypes.index.vec(0))
-  elif all_int(arg): return UOp.const(dtypes.index.vec(len(arg)), arg)
-  else: return UOp(Ops.VECTORIZE, dtypes.index.vec(len(arg)), tuple(UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in arg))
+  if len(arg) == 0: return UOp(Ops.VECTORIZE, dtypes.weakint.vec(0))
+  elif all_int(arg): return UOp.const(dtypes.weakint.vec(len(arg)), arg)
+  else: return UOp(Ops.VECTORIZE, dtypes.weakint.vec(len(arg)), tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
 
 def consumer_map_from_toposort(lst:Iterable[UOp]):
   ret: dict[UOp, dict[UOp, None]] = {}
@@ -426,9 +426,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     assert len(idx) == len(self.shape), f"__getitem__ shape mismatch, indexing {self.shape} with {len(idx)} args"
     if len(slice_idx:=[i for i,x in enumerate(idx) if isinstance(x, slice)]):
       perm = self.permute(tuple([i for i in range(self.ndim) if i not in slice_idx] + slice_idx))
-      return perm.index(*[UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in idx if not isinstance(x, slice)], ptr=True)
+      return perm.index(*[UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in idx if not isinstance(x, slice)], ptr=True)
     else:
-      return self.index(*[UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in idx])
+      return self.index(*[UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in idx])
   def const_like(self, b:ConstLike):
     # constants can optionally have a DEVICE source
     return UOp.const(self.dtype, b, device=self._device, shape=self._shape)
@@ -482,21 +482,21 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
       ret = ret.replace(src=(UOp.unique(None if unique is True else unique),) + ret.src)
     return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None else ret
   @staticmethod
-  def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.index, src=(), **kwargs):
+  def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.weakint, src=(), **kwargs):
     return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs)
   @staticmethod
-  def special(end:sint, name:str, dtype=dtypes.index): return UOp(Ops.SPECIAL, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=name)
+  def special(end:sint, name:str, dtype=dtypes.weakint): return UOp(Ops.SPECIAL, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=name)
   def r(self, op:Ops, axis:tuple[int, ...]):
     axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
     return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis)) if len(axis) else self
   @staticmethod
-  def invalid(count=1): return UOp(Ops.CONST, dtypes.index.vec(count), src=(), arg=Invalid)
+  def invalid(count=1): return UOp(Ops.CONST, dtypes.weakint.vec(count), src=(), arg=Invalid)
   def valid(self, cond): return self if cond.op is Ops.WHERE and cond.arg else cond.where(self, UOp.invalid(self.dtype.count))
   def get_idx(self) -> UOp:
-    assert self.dtype.scalar() is dtypes.index, "Can only call get_idx on index dtype"
+    assert self.dtype.scalar() is dtypes.weakint, "Can only call get_idx on index dtype"
     return self.src[1] if self.op is Ops.WHERE and self.src[2].arg is Invalid else self
   def get_valid(self) -> UOp:
-    assert self.dtype.scalar() is dtypes.index, "Can only call get_valid on index dtype"
+    assert self.dtype.scalar() is dtypes.weakint, "Can only call get_valid on index dtype"
     return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid)
   def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
@@ -757,7 +757,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
   # *** uop Variable stuff ***
 
   @staticmethod
-  def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.index) -> UOp:
+  def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.weakint) -> UOp:
     assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}"
     return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
   @property
@@ -866,7 +866,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     if self.op is Ops.VCONST and Invalid not in self.arg: return (min(self.arg), max(self.arg))
     if self.op is Ops.GEP: return self.src[0]._min_max
     # TODO: CAST to bool/unsigned is not monotone, still some case can be simplified
-    if self.op is Ops.CAST and self.dtype in dtypes.floats+dtypes.sints+(dtypes.index,):
+    if self.op is Ops.CAST and self.dtype in dtypes.floats+dtypes.sints+(dtypes.weakint,):
       return max(self.dtype.min, self.src[0].vmin), min(self.src[0].vmax, self.dtype.max)
     return self.dtype.min, self.dtype.max
@@ -991,7 +991,7 @@ python_alu: dict[Ops, Callable] = {
 def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
   if dtype.count > 1:
     return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
-  if dtype==dtypes.index and op in GroupOp.Binary and Invalid in operands: return Invalid
+  if dtype==dtypes.weakint and op in GroupOp.Binary and Invalid in operands: return Invalid
   alu = python_alu[op](*operands)
   return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu
@@ -1424,23 +1424,26 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=N
   rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx, enter_calls)
   return rewrite_ctx.walk_rewrite(sink) if walk else rewrite_ctx.unified_rewrite(sink)
 
-def sint_to_uop(x:sint, dtype=dtypes.index) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x.cast(dtype)
+def sint_to_uop(x:sint, dtype=dtypes.weakint) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x.cast(dtype)
 
 def select_dtype(u): return (dtypes.long if u.overflows(dtypes.int32) else dtypes.int).vec(u.dtype.count)
 pm_lower_index_dtype = PatternMatcher([
   # There are no Unary ops at this point in symbolic, those are introduced later
-  (UPat(GroupOp.Binary, name="u", src=(UPat.var("x").cast(dtypes.index), UPat.var("y").cast(dtypes.index))), lambda u,x,y:
+  (UPat(GroupOp.Binary, name="u", src=(UPat.var("x").cast(dtypes.weakint), UPat.var("y").cast(dtypes.weakint))), lambda u,x,y:
     x.cast(dt:=least_upper_dtype(select_dtype(u), x.dtype, y.dtype)).alu(u.op, y.cast(dt)).cast(u.dtype)),
-  (UPat((Ops.CONST, Ops.VCONST), dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=select_dtype(u)).cast(u.dtype) if u.arg!=Invalid else None),
-  (UPat(Ops.WHERE, dtypes.index, src=(UPat.var("cond"), UPat.var("x").cast(dtypes.index), UPat.var("y").cast(dtypes.index))), lambda cond,x,y:
-    cond.where(x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)), y.cast(dt)).cast(dtypes.index)),
-  (UPat(Ops.RANGE, src=(UPat.var("end").cast(dtypes.index)), name="r"), lambda r,end: r.replace(dtype=end.dtype, src=(end,)).cast(dtypes.index)),
-  (UPat(Ops.VECTORIZE, src=UPat().cast(dtypes.index), name="v"),
-    lambda v: v.replace(dtype=(dt:=select_dtype(v)), src=tuple(s.src[0].cast(dt.scalar()) for s in v.src)).cast(dtypes.index)),
+  (UPat((Ops.CONST, Ops.VCONST), dtype=dtypes.weakint, name="u"),
+    lambda u: u.replace(dtype=select_dtype(u)).cast(u.dtype) if u.arg!=Invalid else None),
+  (UPat(Ops.WHERE, dtypes.weakint, src=(UPat.var("cond"), UPat.var("x").cast(dtypes.weakint), UPat.var("y").cast(dtypes.weakint))), lambda cond,x,y:
+    cond.where(x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)), y.cast(dt)).cast(dtypes.weakint)),
+  (UPat(Ops.RANGE, src=(UPat.var("end").cast(dtypes.weakint)), name="r"), lambda r,end: r.replace(dtype=end.dtype, src=(end,)).cast(dtypes.weakint)),
+  (UPat(Ops.VECTORIZE, src=UPat().cast(dtypes.weakint), name="v"),
+    lambda v: v.replace(dtype=(dt:=select_dtype(v)), src=tuple(s.src[0].cast(dt.scalar()) for s in v.src)).cast(dtypes.weakint)),
   # special can only be int32
-  (UPat(Ops.SPECIAL, src=(UPat.var("var").cast(dtypes.index),), name="u"), lambda u,var: u.replace(dtype=dtypes.int, src=(var,)).cast(dtypes.index)),
-  (UPat(Ops.DEFINE_VAR, dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=dtypes.int).cast(dtypes.index)),
-  (UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.index), UPat.cvar("val").cast(dtypes.index))), lambda var,val: var.bind(val).cast(dtypes.index)),
+  (UPat(Ops.SPECIAL, src=(UPat.var("var").cast(dtypes.weakint),), name="u"),
+    lambda u,var: u.replace(dtype=dtypes.int, src=(var,)).cast(dtypes.weakint)),
+  (UPat(Ops.DEFINE_VAR, dtype=dtypes.weakint, name="u"), lambda u: u.replace(dtype=dtypes.int).cast(dtypes.weakint)),
+  (UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.weakint), UPat.cvar("val").cast(dtypes.weakint))),
+    lambda var,val: var.bind(val).cast(dtypes.weakint)),
   # lower Invalid
   (UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond, ptr=True)),
   # remove hanging casts
@@ -1448,7 +1451,7 @@ pm_lower_index_dtype = PatternMatcher([
   (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid, ptr=True)),
   (UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),
-    lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.index else s for s in n.src))),
+    lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))),
 ])
 
 def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
@@ -1535,7 +1538,7 @@ pm_pyrender_extra = PatternMatcher([
   (UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"),), name="x"), lambda x,d: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)})"),
   (UPat(Ops.CONST, name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
   (UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x:
-    f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.index else ''})"),
+    f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.weakint else ''})"),
   (UPat((Ops.CAST, Ops.BITCAST), name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({x.dtype})"),
   (UPat(Ops.SPECIAL, src=(UPat(Ops.CONST),), name="x"), lambda x: f"UOp.special({x.src[0].arg}, {repr(x.arg)}, dtype={x.dtype})"),
   (UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"), lambda x,u,d:
@@ -1546,7 +1549,7 @@ pm_pyrender_extra = PatternMatcher([
   # NOTE: range has srcs sometimes after control flow
   (UPat(Ops.RANGE, src=(UPat(Ops.CONST, name="c"),), allow_any_len=True, name="x"), lambda ctx,x,c:
     "UOp.range("+', '.join([str(c.arg)] + [repr(y) for y in x.arg])+
-    (f', src={srcs(ctx, x.src[1:])}' if len(x.src) > 1 else '')+(', dtype='+str(x.dtype) if x.dtype is not dtypes.index else '')+")"),
+    (f', src={srcs(ctx, x.src[1:])}' if len(x.src) > 1 else '')+(', dtype='+str(x.dtype) if x.dtype is not dtypes.weakint else '')+")"),
   # TODO: index shouldn't mismatch dtype
   (UPat(Ops.INDEX, src=(UPat(), UPat()), allow_any_len=True, name="x"), lambda ctx,x:
     f"{ctx[x.src[0]]}.index({ctx[x.src[1]]}, "+''.join([f"{ctx[xx]}, " for xx in x.src[2:]])+
diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py
index bea993ed48..b545fc286a 100644
--- a/tinygrad/uop/spec.py
+++ b/tinygrad/uop/spec.py
@@ -54,7 +54,7 @@ shared_spec = PatternMatcher([
   (UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"),
    lambda rng,x: rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
      all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
-  (UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
+  (UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.weakint for y in x.src[1:]) or None),
name="x"), lambda x: all(y.dtype == dtypes.weakint for y in x.src[1:]) or None), # RANGE/SPECIAL define loops, END closes them (UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE))), lambda: True), @@ -69,13 +69,13 @@ shared_spec = PatternMatcher([ # ***** UOp spec in the Tensor graph ***** movement_ops = PatternMatcher([ - (UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True), - (UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index), UPat(dtype=dtypes.index))), lambda mv,x: True), + (UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.weakint))), lambda mv,x: True), + (UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.weakint), UPat(dtype=dtypes.weakint))), lambda mv,x: True), (UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)), # inputs to movement ops - (UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True), - (UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True), + (UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.weakint), lambda: True), + (UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.weakint), lambda: True), # AFTER on Movement Op, BUFFER, COPY, or BITCAST (UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS, Ops.BUFFER, Ops.BITCAST, Ops.COPY})),), allow_any_len=True), @@ -104,9 +104,9 @@ _tensor_spec = PatternMatcher([ (UPat(Ops.MSTACK, name="x"), lambda x: all(isinstance(x.device, str) for x in x.src)), # Tensor variable bindings - (UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.index,))), arg=None), lambda: True), + (UPat(Ops.BIND, (dtypes.int,dtypes.weakint,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.weakint,))), arg=None), lambda: True), # single-src BIND used for schedule cache key normalization - (UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR),), arg=None), lambda: True), + (UPat(Ops.BIND, (dtypes.int,dtypes.weakint,), (UPat(Ops.DEFINE_VAR),), arg=None), lambda: True), # device or unique (UPat(Ops.CONST, src=(UPat(Ops.DEVICE),)), lambda: True), @@ -189,7 +189,7 @@ shared_codegen_spec = PatternMatcher([ (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf"), UPat.var("idx"), UPat.var("gate", dtype=dtypes.bool))), validate_index), # SPECIAL - (UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.index, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)), + (UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.weakint, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)), # BARRIER (on any length) (UPat(Ops.BARRIER, dtypes.void), lambda: True), @@ -199,7 +199,7 @@ shared_codegen_spec = PatternMatcher([ kernel_spec = PatternMatcher([ # index is allowed here - (UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True), + (UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.weakint), lambda: True), # UNROLL/CONTRACT is used here for WMMA (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)), @@ -212,7 +212,7 @@ kernel_spec = PatternMatcher([ (UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: True), # reduce must be on ranges - (UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype in (dtypes.index, dtypes.int) for y in x.src[1:])), + 
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype in (dtypes.weakint, dtypes.int) for y in x.src[1:])), # COPY/BUFFER_VIEW can have ranges appended (UPat(Ops.COPY, name="x", src=(UPat.var("s"), UPat(Ops.DEVICE)), allow_any_len=True, arg=None), @@ -236,7 +236,7 @@ program_spec = PatternMatcher([ (UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True), # make sure all index dtypes have been lowered (except CONST/RANGE/DEFINE_VAR which are valid index-typed) - (UPat(GroupOp.All-{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR, Ops.VCONST, Ops.VECTORIZE}, dtype=dtypes.index), lambda: False), + (UPat(GroupOp.All-{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR, Ops.VCONST, Ops.VECTORIZE}, dtype=dtypes.weakint), lambda: False), (UPat(Ops.CONST, arg=Invalid), lambda: False), (UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.arg) and len(x.arg)==x.dtype.vcount>1 and type(x.arg) is type(x.dtype.const(x.arg))), @@ -273,13 +273,13 @@ full_spec = PatternMatcher([ (UPat(Ops.CALL, dtype=dtypes.void), lambda: True), # where on index in rhs position is fine - (UPat(Ops.WHERE, dtype=dtypes.index, src=(UPat(dtype=dtypes.bool), UPat(), UPat(dtype=dtypes.index))), lambda: True), + (UPat(Ops.WHERE, dtype=dtypes.weakint, src=(UPat(dtype=dtypes.bool), UPat(), UPat(dtype=dtypes.weakint))), lambda: True), # allow index dtype on a restricted set of UOps (UPat((Ops.ADD, Ops.MUL, Ops.MOD, Ops.IDIV, Ops.MAX, - Ops.SPECIAL, Ops.CAST, Ops.RANGE, Ops.VCONST, Ops.VECTORIZE), dtype=dtypes.index), lambda: True), + Ops.SPECIAL, Ops.CAST, Ops.RANGE, Ops.VCONST, Ops.VECTORIZE), dtype=dtypes.weakint), lambda: True), # while BIND is being casted - (UPat(Ops.BIND, (dtypes.int, dtypes.index), (UPat(), UPat()), arg=None), lambda: True), + (UPat(Ops.BIND, (dtypes.int, dtypes.weakint), (UPat(), UPat()), arg=None), lambda: True), # in progress MSTACK may lose device (UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True), diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index e1ce30ef91..8849cb6e80 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -51,15 +51,15 @@ def fold_add_divmod_recombine(x:UOp) -> UOp|None: # this needs to be before symbolic so that 0*something_that_might_be_invalid doesnt become 0 propagate_invalid = PatternMatcher([ # propagate invalid, push it past children - (invalid_gate.cast(name="cast"), lambda i,x,cond,cast: x.cast(cast.dtype) if i.dtype is dtypes.index else None), + (invalid_gate.cast(name="cast"), lambda i,x,cond,cast: x.cast(cast.dtype) if i.dtype is dtypes.weakint else None), (UPat(GroupOp.Unary, src=(invalid_gate,), name="alu"), lambda cond,x,alu,i: cond.where(x.alu(alu.op), i)), (UPat(GroupOp.Binary-GroupOp.Comparison, src=(invalid_gate, UPat.var("y")), name="alu"), lambda cond,x,y,alu,i: cond.where(x.alu(alu.op,y), i)), (UPat(GroupOp.Binary-GroupOp.Comparison, src=(UPat.var("y"), invalid_gate), name="alu"), lambda cond,x,y,alu,i: cond.where(y.alu(alu.op,x), i)), # TODO: when can this happen? and is it always safe to just drop invalid? 
   (UPat(GroupOp.Comparison, src=(invalid_gate, UPat.var("y")), name="alu"), lambda cond,x,y,alu,i:
-    x.alu(alu.op,y) if i.dtype is dtypes.index else cond.where(x.alu(alu.op,y), i.cast(dtypes.bool))),
+    x.alu(alu.op,y) if i.dtype is dtypes.weakint else cond.where(x.alu(alu.op,y), i.cast(dtypes.bool))),
   (UPat(GroupOp.Comparison, src=(UPat.var("y"), invalid_gate), name="alu"), lambda cond,x,y,alu,i:
-    y.alu(alu.op,x) if i.dtype is dtypes.index else cond.where(y.alu(alu.op,x), i.cast(dtypes.bool))),
+    y.alu(alu.op,x) if i.dtype is dtypes.weakint else cond.where(y.alu(alu.op,x), i.cast(dtypes.bool))),
   # alu with invalid -> invalid
   (UPat(GroupOp.Unary, src=(invalid_pat,)), lambda i: i),
   (UPat(GroupOp.Binary-GroupOp.Comparison, src=[invalid_pat, UPat()]), lambda i: i),
@@ -77,26 +77,26 @@ symbolic_simple = propagate_invalid + PatternMatcher([
   # ** self folding **
   (UPat.var("x") + 0, lambda x: x),    # x+0 -> x
   (UPat.var("x") * 1, lambda x: x),    # x*1 -> x
-  (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) ^ 0, lambda x: x),  # x^0 -> x
+  (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) ^ 0, lambda x: x),  # x^0 -> x
   (UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)),  # x//x -> 1
   (UPat.var("x") // 1, lambda x: x),   # x//1 -> x
   (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
   ((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base),  # (x%y)%y = -> x%y (rewritten with base for speed)
   # variations of (x%c)+(x//c)*c = x
-  (UPat(Ops.ADD, dtype=dtypes.index, name="x"), fold_add_divmod_recombine),
+  (UPat(Ops.ADD, dtype=dtypes.weakint, name="x"), fold_add_divmod_recombine),
   (UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
   (UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
   (UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),
   (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
   (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
   (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, False), UPat.const(dtypes.bool, True)), lambda x: x.logical_not()),
-  (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)).trunc(), lambda x: x),
+  (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)).trunc(), lambda x: x),
   # ** zero folding **
   (UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
   (UPat.var("x") % UPat.var("x"), lambda x: x.const_like(0)),   # x%x -> 0
   (UPat.var("x") ^ UPat.var("x"), lambda x: x.const_like(0)),   # x^x -> 0
   (UPat.var("x") & 0, lambda x: x.const_like(0)),               # x&0 -> 0
-  (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) != UPat.var("x"),
+  (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) != UPat.var("x"),
    lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))),  # x != x -> False (only ints)
   # ** constant folding **
   # TODO: add const folding for Ops.THREEFRY
@@ -192,7 +192,7 @@ gep_pushing = PatternMatcher([
   # GEP in order is removed
   (UPat(Ops.GEP, name="g"), lambda g: g.src[0] if not isinstance(g.dtype, PtrDType) and g.arg == tuple(range(g.src[0].dtype.count)) else None),
   # push all GEPs through ALUs for index (TODO: remove this)
-  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu').f(Ops.GEP, dtype=dtypes.index, name='gep'),
+  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu').f(Ops.GEP, dtype=dtypes.weakint, name='gep'),
    lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
      if not isinstance(gep.dtype, PtrDType) and not isinstance(alu.dtype, PtrDType) else None),
   # CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)
@@ -207,7 +207,7 @@ gep_pushing = PatternMatcher([
 commutative = PatternMatcher([
   # ** COMMUTATIVE flipping (only for index) **
   # NOTE: this can break merging vector math by only flipping some of them
-  (UPat(GroupOp.Commutative, dtype=dtypes.index, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
+  (UPat(GroupOp.Commutative, dtype=dtypes.weakint, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
 ])
 
 symbolic = symbolic_simple+commutative+PatternMatcher([
@@ -224,7 +224,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
   ((UPat.var("y") + UPat.var("x")) + UPat.var("x"), lambda y,x: y+x*2),
   ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3) if x2 is not x3 else None),  # (x/x2)/x3 -> x/(x2*x3)
   (-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)),  # -(x+c) -> -x + -c
-  (UPat.cvar("y") * (UPat.var("x", dtype=dtypes.index) + UPat.cvar("c")), lambda x,y,c: (y*x)+(y*c)),  # y*(x+c) -> y*x + y*c
+  (UPat.cvar("y") * (UPat.var("x", dtype=dtypes.weakint) + UPat.cvar("c")), lambda x,y,c: (y*x)+(y*c)),  # y*(x+c) -> y*x + y*c
   # ** where folding **
   (UPat.var("cond", dtype=dtypes.bool).logical_not().where(UPat.var("t"), UPat.var("f")),
    lambda cond, t, f: cond.where(f,t) if f.arg is not Invalid else None),
@@ -249,35 +249,35 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
   ((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)),  # (x//c1)//c2 -> x//(c1*c2)
   # ** lt **
   # c0*x 0 and c1.arg > 0 else None),
   # c0*x 0 else x<(c.arg*d.arg-(d.arg-1))) if d.arg > 0 else None),
   # ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
   ((UPat.var("x") + UPat.cvar("c1")) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
   ((UPat.var("x") * UPat.cvar("c1")) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
   # *** rules from symbolic ***
   # generic lt folding
-  (UPat.var("x", dtypes.index) 0. NOTE: not x < 1 means x > 0
@@ -432,7 +432,7 @@ sym = symbolic+pm_simplify_valid+PatternMatcher([
     UPat.load(UPat(Ops.INDEX, name="index"))), allow_any_len=True, name="store"),
    lambda index, gate, alt, store: UOp.store(index.src[0].index(gate.where(index.src[1], UOp.invalid())), alt, *store.src[2:])),
   # fold gated LOAD/STORE
-  (UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.index, Invalid)).or_casted(),), allow_any_len=True, name="x"),
+  (UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.weakint, Invalid)).or_casted(),), allow_any_len=True, name="x"),
    lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)),  # invalid store does nothing. invalid load produces 0
   (UPat(Ops.STORE, src=(UPat(), invalid_pat), allow_any_len=True), lambda i: UOp(Ops.NOOP)),
   # store of where with invalid -> gated store
@@ -456,5 +456,5 @@ sym = symbolic+pm_simplify_valid+PatternMatcher([
   # ** combine terms (opinionated) **
   (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)),  # -(x+y) -> -x + -y
   # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
-  ((UPat.var("x", dtypes.index) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
+  ((UPat.var("x", dtypes.weakint) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
 ])
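The int-only restriction on that last distribute rule is load-bearing: with floats, (x+y)*c and x*c+y*c can genuinely disagree once x+y overflows to inf, since inf*0.0 is nan. A two-line demonstration:

```python
x = y = 1e308       # x+y overflows to float inf
c = 0.0
print((x + y) * c)  # inf * 0.0 -> nan
print(x*c + y*c)    # 0.0 + 0.0 -> 0.0
```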
diff --git a/tinygrad/uop/validate.py b/tinygrad/uop/validate.py
index df9c3af232..b1f888bddf 100644
--- a/tinygrad/uop/validate.py
+++ b/tinygrad/uop/validate.py
@@ -20,33 +20,33 @@ def create_bounded(name:str, vmin:int, vmax:int, z3ctx:z3.Context) -> tuple[z3.A
   return (s:=z3.Int(name, ctx=z3ctx)), (vmin <= s)&(s <= vmax)

 z3_renderer = PatternMatcher([
-  (UPat.var("cond").where(UPat.var("x"), UPat.const(dtypes.index, Invalid)), lambda x,cond,ctx: (ctx[1][x], ctx[1][cond])),
+  (UPat.var("cond").where(UPat.var("x"), UPat.const(dtypes.weakint, Invalid)), lambda x,cond,ctx: (ctx[1][x], ctx[1][cond])),
   # variables
   (UPat(Ops.SPECIAL, name="x"), lambda x,ctx: create_bounded(x.arg, 0, ctx[1][x.src[0]]-1, ctx[0])),
   (UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])),
   (UPat(Ops.RANGE, name="x"), lambda x,ctx: create_bounded(x.render(simplify=False), 0, ctx[1][x.src[0]]-1, ctx[0])),
   # loads are variables bounded by the min/max of the dtype. non-pointer INDEX is also a LOAD
-  (UPat((Ops.LOAD, Ops.INDEX), dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx:
+  (UPat((Ops.LOAD, Ops.INDEX), dtypes.ints+(dtypes.weakint,), name="x"), lambda x,ctx:
    create_bounded(f"load{len(ctx[1])}", x.dtype.min, x.dtype.max, ctx[0])),
   (UPat((Ops.LOAD, Ops.INDEX), dtypes.bool, name="x"), lambda x,ctx: (z3.Bool(f"load{len(ctx[1])}", ctx=ctx[0]), None)),
   # constants
   (UPat(Ops.CONST, arg=Invalid, name="x"), lambda x,ctx: (z3.Int("Invalid", ctx=ctx[0]), None)),
-  (UPat(Ops.CONST, dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx: (z3.IntVal(x.arg, ctx=ctx[0]), None)),
+  (UPat(Ops.CONST, dtypes.ints+(dtypes.weakint,), name="x"), lambda x,ctx: (z3.IntVal(x.arg, ctx=ctx[0]), None)),
   (UPat(Ops.CONST, dtypes.bool, name="x"), lambda x,ctx: (z3.BoolVal(x.arg, ctx=ctx[0]), None)),
   # casts from floats create new variables
-  (UPat(Ops.CAST, dtypes.ints+(dtypes.index,), src=(UPat(dtype=dtypes.floats),), name="x"), lambda x,ctx:
+  (UPat(Ops.CAST, dtypes.ints+(dtypes.weakint,), src=(UPat(dtype=dtypes.floats),), name="x"), lambda x,ctx:
    create_bounded(f"cast{len(ctx[1])}", x.dtype.min, x.dtype.max, ctx[0])),
   # A comparison between floats introduces a new bool variable
   (UPat(GroupOp.Comparison, src=UPat(dtype=dtypes.floats), name="x"), lambda x,ctx: (z3.Bool(f"float_cmp{len(ctx[1])}", ctx=ctx[0]), None)),
   # casts from bool/int to int/bool
-  (UPat(Ops.CAST, dtypes.ints+(dtypes.index,),src=(UPat.var("x", dtypes.bool),), name="c"), lambda x,c,ctx: (z3.If(ctx[1][x], 1, 0), None)),
-  (UPat(Ops.CAST, dtypes.ints+(dtypes.index,), src=(UPat.var("x", dtypes.ints+(dtypes.index,)),), name="c"), lambda x,c,ctx: (ctx[1][x], None)),
+  (UPat(Ops.CAST, dtypes.ints+(dtypes.weakint,),src=(UPat.var("x", dtypes.bool),), name="c"), lambda x,c,ctx: (z3.If(ctx[1][x], 1, 0), None)),
+  (UPat(Ops.CAST, dtypes.ints+(dtypes.weakint,), src=(UPat.var("x", dtypes.ints+(dtypes.weakint,)),), name="c"), lambda x,c,ctx: (ctx[1][x], None)),
   (UPat(Ops.CAST, dtypes.bool, name="x"), lambda x,ctx: (ctx[1][x.src[0]]!=0, None)),
   (UPat(GroupOp.ALU, name="x"), lambda x,ctx: (z3_alu[x.op](*(ctx[1][s] for s in x.src)), None)),
 ])

 def uops_to_z3(solver:z3.Solver, *uops: UOp) -> list[z3.ExprRef]:
-  lst = list(UOp.sink(*uops).toposort(gate=lambda x: x.dtype.scalar() in dtypes.ints+(dtypes.bool, dtypes.index) or x.op is Ops.SINK))[:-1]
+  lst = list(UOp.sink(*uops).toposort(gate=lambda x: x.dtype.scalar() in dtypes.ints+(dtypes.bool, dtypes.weakint) or x.op is Ops.SINK))[:-1]
   z3map: dict[UOp, z3.ExprRef] = {}
   for u in lst:
     z3_rewritten = z3_renderer.rewrite(u, ctx=(solver.ctx, z3map))
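For context on z3_renderer: this is the standard translation-validation trick. Every integer UOp becomes a bounded z3 integer, and a claimed rewrite is accepted only if z3 finds no counterexample to its negation. A minimal standalone sketch of that check (plain z3; the bounds and the identity being checked are made up for illustration):

```python
import z3

s = z3.Solver()
x = z3.Int("x")
s.add(0 <= x, x <= 100)         # bounds, as create_bounded adds for a variable
s.add((x*2) % 2 != 0)           # assert the NEGATION of the claimed identity
assert s.check() == z3.unsat    # unsat: no counterexample within the bounds
```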
diff --git a/tinygrad/viz/js/worker.js b/tinygrad/viz/js/worker.js
index 2ca5b6298f..9e838dbf7b 100644
--- a/tinygrad/viz/js/worker.js
+++ b/tinygrad/viz/js/worker.js
@@ -74,7 +74,7 @@ const layoutUOp = (g, { graph, change }, opts) => {
   if (!opts.showIndexing) {
     for (const n of g.nodes()) {
       const node = g.node(n);
-      if (node.label.includes("dtypes.index")) g.removeNode(n);
+      if (node.label.includes("dtypes.weakint")) g.removeNode(n);
     }
   }
   if (!opts.showCallSrc || opts.callSrcMask.size > 0) {
diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py
index dfce868a6d..647bf90c9b 100755
--- a/tinygrad/viz/serve.py
+++ b/tinygrad/viz/serve.py
@@ -104,7 +104,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
   # always exclude DEVICE/CONST/UNIQUE
   if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
   if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u)
-  if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.index and u is not x: excluded.add(u)
+  if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.weakint and u is not x: excluded.add(u)
   if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u)
   # exclude RESHAPE/EXPAND that only serve to broadcast a CONST
   if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
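The viz changes are mechanical: the graph view hides indexing nodes by substring-matching their label against the dtype's repr, so the filter string simply tracks the rename. A sketch of that filtering logic (hypothetical labels, not actual uop_to_json output):

```python
# Hypothetical node labels for illustration; real ones come from uop_to_json.
graph = {
  0: {"label": "CONST dtypes.weakint 0"},
  1: {"label": "ADD dtypes.float"},
}
# keep only nodes whose label does not mention the indexing dtype
visible = {k: v for k, v in graph.items() if "dtypes.weakint" not in v["label"]}
assert list(visible) == [1]
```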