From afad7d0cd15a73c728f2a8fa95bf2cdc817a96a4 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 29 Aug 2025 09:52:07 -0700 Subject: [PATCH] remove dtype from range, it will be dtypes.index soon [pr] (#11914) * remove dtype from range, it will be dtypes.index soon [pr] * a few more --- extra/gemm/amd_uop_matmul.py | 68 ++++++++++++++-------------- test/test_outerworld_range.py | 4 +- test/test_rangeify.py | 12 ++--- test/test_uop_graph.py | 10 ++-- test/test_uops.py | 2 +- test/unit/test_graph_rewrite.py | 10 ++-- test/unit/test_simplify_valid_idx.py | 2 +- tinygrad/codegen/lowerer.py | 3 +- tinygrad/schedule/rangeify.py | 2 +- tinygrad/shape/view.py | 2 +- tinygrad/uop/ops.py | 4 +- 11 files changed, 58 insertions(+), 61 deletions(-) diff --git a/extra/gemm/amd_uop_matmul.py b/extra/gemm/amd_uop_matmul.py index 9ad5351b57..78dbf81a1d 100644 --- a/extra/gemm/amd_uop_matmul.py +++ b/extra/gemm/amd_uop_matmul.py @@ -65,7 +65,7 @@ def top_spec_kernel3(): c = a@b sink = c.schedule()[-1].ast L = 16 - sink = sink.reshape((N//L, L, N//L, L)) #.lift({0:UOp.range(dtypes.int, N//BM, 0), 2:UOp.range(dtypes.int, N//BN, 1)}) + sink = sink.reshape((N//L, L, N//L, L)) #.lift({0:UOp.range(N//BM, 0), 2:UOp.range(N//BN, 1)}) sink = graph_rewrite(sink, view_left+pm) axis_types = (AxisType.GLOBAL, AxisType.LOCAL, AxisType.GLOBAL, AxisType.LOCAL, AxisType.REDUCE) return sink.replace(arg=KernelInfo(name="top_"+to_colored(sink.full_shape, axis_types), axis_types=axis_types)) @@ -186,7 +186,7 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)): c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), arg=2) - i = UOp.range(dtypes.int, c_regs.dtype.size, 16) + i = UOp.range(c_regs.dtype.size, 16) init_store = c_regs[i].store(UOp.const(dtypes.float, 0.0), i) if kernel4: @@ -197,53 +197,53 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)): kId = 0 # load from globals into locals - i = UOp.range(dtypes.int, nbReadsB, 0) + i = UOp.range(nbReadsB, 0) index_x = BN * blockIdx_x + rBIdx index_y = rBIdy + i * strideReadB + kId Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i) - i = UOp.range(dtypes.int, nbReadsA, 1) + i = UOp.range(nbReadsA, 1) index_x = rAIdx + kId index_y = BM * blockIdx_y + rAIdy + i * strideReadA As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i) # iterate over the middle chunk - kId_range = UOp.range(dtypes.int, N//BK-1, 2) + kId_range = UOp.range(N//BK-1, 2) kId = kId_range*BK barrier = UOp.barrier(As_store, Bs_store) # load from globals into registers (next round) - i = UOp.range(dtypes.int, nbReadsB, 3) + i = UOp.range(nbReadsB, 3) index_x = BN * blockIdx_x + rBIdx index_y = rBIdy + i * strideReadB + kId + BK regB_store = regB[i].store(b[N * index_y + index_x].load(), i) - i = UOp.range(dtypes.int, nbReadsA, 4) + i = UOp.range(nbReadsA, 4) index_x = rAIdx + kId + BK index_y = BM * blockIdx_y + rAIdy + i * strideReadA regA_store = regA[i].store(a[N * index_y + index_x].load(), i) def inner_loop(first_range, inp_dep=()): # inner unroll - k = UOp.range(dtypes.int, BK, first_range+0) + k = UOp.range(BK, first_range+0) # load from locals into registers - iterWave = UOp.range(dtypes.int, nbIterWaveN, first_range+1) - i = UOp.range(dtypes.int, TN, first_range+2) + iterWave = UOp.range(nbIterWaveN, first_range+1) + i = UOp.range(TN, first_range+2) index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(*inp_dep), iterWave, i) - iterWave = UOp.range(dtypes.int, nbIterWaveM, first_range+3) - i = UOp.range(dtypes.int, TM, first_range+4) + iterWave = UOp.range(nbIterWaveM, first_range+3) + i = UOp.range(TM, first_range+4) index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(*inp_dep), iterWave, i) # do the GEMM math - iterWaveM = UOp.range(dtypes.int, nbIterWaveM, first_range+5) - yt = UOp.range(dtypes.int, TM, first_range+6) - iterWaveN = UOp.range(dtypes.int, nbIterWaveN, first_range+7) - xt = UOp.range(dtypes.int, TN, first_range+8) + iterWaveM = UOp.range(nbIterWaveM, first_range+5) + yt = UOp.range(TM, first_range+6) + iterWaveN = UOp.range(nbIterWaveN, first_range+7) + xt = UOp.range(TN, first_range+8) x = iterWaveN * TN + xt y = iterWaveM * TM + yt c_regs_idx = c_regs[y * TN * nbIterWaveN + x] @@ -256,12 +256,12 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)): sink = inner_loop(5, (barrier, regB_store, regA_store)).barrier() # load from registers into locals - i = UOp.range(dtypes.int, nbReadsB, 14) + i = UOp.range(nbReadsB, 14) index_x = BN * blockIdx_x + rBIdx index_y = rBIdy + i * strideReadB + kId + BK Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(regB[i].load(sink), i, kId_range) - i = UOp.range(dtypes.int, nbReadsA, 15) + i = UOp.range(nbReadsA, 15) index_x = rAIdx + kId + BK index_y = BM * blockIdx_y + rAIdy + i * strideReadA As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(regA[i].load(sink), i, kId_range) @@ -269,40 +269,40 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)): # final iteration without the copy sink = inner_loop(16, (UOp.barrier(Bs_store, As_store),)) else: - kId_range = UOp.range(dtypes.int, N//BK, 0) + kId_range = UOp.range(N//BK, 0) kId = kId_range*BK # load from globals into locals - i = UOp.range(dtypes.int, nbReadsB, 1) + i = UOp.range(nbReadsB, 1) index_x = BN * blockIdx_x + rBIdx index_y = rBIdy + i * strideReadB + kId Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i) - i = UOp.range(dtypes.int, nbReadsA, 2) + i = UOp.range(nbReadsA, 2) index_x = rAIdx + kId index_y = BM * blockIdx_y + rAIdy + i * strideReadA As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i) barrier = UOp.barrier(As_store, Bs_store) - k = UOp.range(dtypes.int, BK, 3) + k = UOp.range(BK, 3) # load from locals into registers - iterWave = UOp.range(dtypes.int, nbIterWaveN, 4) - i = UOp.range(dtypes.int, TN, 5) + iterWave = UOp.range(nbIterWaveN, 4) + i = UOp.range(TN, 5) index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(barrier), iterWave, i) - iterWave = UOp.range(dtypes.int, nbIterWaveM, 6) - i = UOp.range(dtypes.int, TM, 7) + iterWave = UOp.range(nbIterWaveM, 6) + i = UOp.range(TM, 7) index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(barrier), iterWave, i) # do the GEMM math - iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 8) - yt = UOp.range(dtypes.int, TM, 9) - iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 10) - xt = UOp.range(dtypes.int, TN, 12) + iterWaveM = UOp.range(nbIterWaveM, 8) + yt = UOp.range(TM, 9) + iterWaveN = UOp.range(nbIterWaveN, 10) + xt = UOp.range(TN, 12) x = iterWaveN * TN + xt y = iterWaveM * TM + yt c_regs_idx = c_regs[y * TN * nbIterWaveN + x] @@ -310,10 +310,10 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)): iterWaveM, iterWaveN, yt, xt, k, kId_range) # store c_regs into c - iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 1000) - yt = UOp.range(dtypes.int, TM, 1001) - iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 1002) - xt = UOp.range(dtypes.int, TN, 1003) + iterWaveM = UOp.range(nbIterWaveM, 1000) + yt = UOp.range(TM, 1001) + iterWaveN = UOp.range(nbIterWaveN, 1002) + xt = UOp.range(TN, 1003) xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave indexC = N * (yOut + yt) + xOut + xt diff --git a/test/test_outerworld_range.py b/test/test_outerworld_range.py index 36fac3e2ab..cfc610cde8 100644 --- a/test/test_outerworld_range.py +++ b/test/test_outerworld_range.py @@ -1,5 +1,5 @@ import unittest -from tinygrad import Tensor, nn, Variable, UOp, dtypes +from tinygrad import Tensor, nn, Variable, UOp # outerworld range should support three things # 1. full optimizer steps (test_model_bound_range) @@ -136,7 +136,7 @@ class TestOuterworldRange(unittest.TestCase): def test_model_bound_range(self): m, opt = get_model_and_opt() # TODO: should ranges be unique so you don't have to pass in the -1? - rng = UOp.range(dtypes.int, self.STEPS, -1) + rng = UOp.range(self.STEPS, -1) vib = Variable('i', 0, self.STEPS-1).bind(rng) loss = (m(self.X[vib]) - self.Y[vib]).square().mean() loss.backward() diff --git a/test/test_rangeify.py b/test/test_rangeify.py index c2c779ed42..8558c333ea 100644 --- a/test/test_rangeify.py +++ b/test/test_rangeify.py @@ -1,6 +1,7 @@ import unittest from tinygrad import Tensor from tinygrad.helpers import RANGEIFY, Context, GlobalCounters +from tinygrad.uop.ops import UOp N = 256 @@ -141,9 +142,6 @@ class TestRangeify(unittest.TestCase): print(f"mse: {mse}") self.assertLessEqual(mse, 1e-6) -from tinygrad import dtypes -from tinygrad.uop.ops import UOp - # contiguous + reduce can support ranges? @unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY") @@ -152,7 +150,7 @@ class TestOuterworld(unittest.TestCase): t = Tensor.rand(10, 10).realize() # passthrough ranges - a = UOp.range(dtypes.int, 10, -1) + a = UOp.range(10, -1) sel = t[a] cpy = sel.contiguous(a).realize() @@ -162,7 +160,7 @@ class TestOuterworld(unittest.TestCase): t = Tensor.rand(10, 10).realize() # passthrough ranges - a = UOp.range(dtypes.int, 10, -1) + a = UOp.range(10, -1) sel = t[9-a] cpy = sel.contiguous(a).realize() @@ -174,7 +172,7 @@ class TestOuterworld(unittest.TestCase): x = Tensor.ones(3, 10, 2).contiguous() # vmap across axis 0 - a = UOp.range(dtypes.int, 3, -1) + a = UOp.range(3, -1) out = f(x[a]) out = out.contiguous(a) @@ -188,7 +186,7 @@ class TestOuterworld(unittest.TestCase): manual = (x @ W[0] @ W[1] @ W[2]).contiguous().realize() - a = UOp.range(dtypes.int, 3, -1) + a = UOp.range(3, -1) x = x.assign(x @ W[a]) out = x.contiguous(a)[-1].contiguous().realize() diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 945c2b6a38..b56fe1a922 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -477,7 +477,7 @@ class TestUOpGraph(unittest.TestCase): def test_load_with_float_in_index(self): with Context(IGNORE_OOB=0): - ridx = UOp.range(dtypes.int, 20, 0) + ridx = UOp.range(20, 0) glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int) ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16))),)) @@ -490,7 +490,7 @@ class TestUOpGraph(unittest.TestCase): def test_load_cast_to_bool(self): with Context(IGNORE_OOB=0): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0) - ridx = UOp.range(dtypes.int, 20, 0) + ridx = UOp.range(20, 0) ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx, ridx.cast(dtypes.bool).logical_not()),)) to_uops_list([ld0]) @@ -499,7 +499,7 @@ class TestUOpGraph(unittest.TestCase): with Context(IGNORE_OOB=0): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0) - ridx = UOp.range(dtypes.int, 20, 0) + ridx = UOp.range(20, 0) ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask),))) to_uops_list([ld0]) @@ -592,8 +592,8 @@ class TestUOpGraph(unittest.TestCase): def test_switched_range_order(self): glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) cf = UOp.const(dtypes.float, 0.0) - r1 = UOp.range(dtypes.int, 2, 0) - r2 = UOp.range(dtypes.int, 2, 1) + r1 = UOp.range(2, 0) + r2 = UOp.range(2, 1) alu = UOp(Ops.MUL, dtypes.int, (r2, r1)) store = UOp(Ops.STORE, dtypes.void, (glbl.index(alu), cf)) uops = to_uops_list([store]) diff --git a/test/test_uops.py b/test/test_uops.py index 122292b716..e1af64daf4 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -403,7 +403,7 @@ class TestAssembly(unittest.TestCase): self.assertNotIn(Ops.IDIV, ops) def test_fast_idiv_remove_powers_of_two(self): - ridx = UOp.range(dtypes.int, 2**20, 0) + ridx = UOp.range(2**20, 0) uops = to_uops_list([ridx//(7*64)], opts=Device[Device.DEFAULT].renderer) ops = [x.op for x in uops] # this requires shifting out the powers of two before doing fast_idiv diff --git a/test/unit/test_graph_rewrite.py b/test/unit/test_graph_rewrite.py index 30d7454210..df12fb81a6 100644 --- a/test/unit/test_graph_rewrite.py +++ b/test/unit/test_graph_rewrite.py @@ -65,21 +65,21 @@ class TestFoldingAndReduction(unittest.TestCase): def test_full_graph_rewrite_reduction_with_unused_range(self): const1 = UOp.const(dtypes.int32, 15) const2 = UOp.const(dtypes.int32, 25) - rng = UOp.range(dtypes.int32, 10, idx=0) + rng = UOp.range(10, idx=0) optimized_sink = apply_rewrite((const1 + const2).reduce(Ops.ADD, rng)) expected_sum = 10 * (15 + 25) self.assertEqual(optimized_sink.arg, expected_sum) @unittest.skip("currently failing") def test_full_graph_rewrite_range_reduction(self): - simple_range = UOp.range(dtypes.int32, 5, idx=0) + simple_range = UOp.range(5, idx=0) optimized_sink = apply_rewrite(simple_range.reduce(Ops.ADD, simple_range)) expected_sum = sum(range(5)) self.assertEqual(optimized_sink.arg, expected_sum) @unittest.skip("currently failing") def test_full_graph_rewrite_simple_reduction_folding(self): - simple_range = UOp.range(dtypes.int32, 4, idx=0) + simple_range = UOp.range(4, idx=0) add_uop = simple_range + UOp.const(dtypes.int32, 1) optimized_sink = apply_rewrite(add_uop.reduce(Ops.ADD, simple_range)) expected_sum = sum(i + 1 for i in range(4)) @@ -87,8 +87,8 @@ class TestFoldingAndReduction(unittest.TestCase): @unittest.skip("currently failing") def test_full_graph_rewrite_nested_loop_collapse(self): - outer_range = UOp.range(dtypes.int32, 8, 0) - inner_range = UOp.range(dtypes.int32, 4, 1) + outer_range = UOp.range(8, 0) + inner_range = UOp.range(4, 1) expr = (outer_range * 10) + inner_range optimized_reduce_uop = apply_rewrite(expr.reduce(Ops.ADD, outer_range, inner_range)) self.assertEqual(optimized_reduce_uop.op, Ops.CONST) diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py index 47b8ce4ff5..77869f5b17 100644 --- a/test/unit/test_simplify_valid_idx.py +++ b/test/unit/test_simplify_valid_idx.py @@ -19,7 +19,7 @@ def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UO def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.int, (), (expr, nmax)) def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax) -def Range(n, nmax): return UOp.range(dtypes.int, nmax, n) +def Range(n, nmax): return UOp.range(nmax, n) class TestHelpers(unittest.TestCase): def test_is_increasing(self): diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index b9d63f6f87..d7811f5bcd 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -1,6 +1,5 @@ # the job of the lowerer is to do indexing from dataclasses import dataclass -from tinygrad.dtype import dtypes from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite, resolve # ***** indexing ***** @@ -12,7 +11,7 @@ class IndexContext: start: int = 0 def shape_to_idx(s, axis_types, start=0): - return [UOp.range(dtypes.int, sint_to_uop(s), start+i, at) for i, (s, at) in enumerate(zip(s, axis_types))] + return [UOp.range(sint_to_uop(s), start+i, at) for i, (s, at) in enumerate(zip(s, axis_types))] def get_index(ast:UOp) -> IndexContext: axis_types = ast.arg.axis_types if isinstance(ast.arg, KernelInfo) else () diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 83578efb71..4a527edd57 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -109,7 +109,7 @@ class RangeifyContext: # create ranges range_idx: int = 0 def new_range(self, s:sint, axistype:AxisType=AxisType.LOOP): - ret = UOp.range(dtypes.int, s, self.range_idx, axistype) + ret = UOp.range(s, self.range_idx, axistype) self.range_idx += 1 return ret diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index fcef6265ae..6b4cd22bb6 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -114,7 +114,7 @@ class View: def to_indexed_uops(self:View, idxs:Sequence[UOp]|None=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> tuple[UOp, UOp]: """(idx, valid)""" - if idxs is None: idxs = [UOp.range(dtypes.int, s, i) for i,s in enumerate(self.shape)] + if idxs is None: idxs = [UOp.range(s, i) for i,s in enumerate(self.shape)] iexpr = sint_to_uop(self.offset) for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else itertools.repeat(None)): if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 25dd38e9d2..784c5c070d 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -297,10 +297,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass): else: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),)) return ret @staticmethod - def range(dtype:DType, end:sint, *arg): + def range(end:sint, *arg): if len(arg) == 0: raise RuntimeError("range needs an arg") if len(arg) == 1: arg = arg+(AxisType.LOOP,) - return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=arg) + return UOp(Ops.RANGE, dtype=dtypes.int, src=(sint_to_uop(end),), arg=arg) def r(self, op:Ops, axis:tuple[int, ...]): axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)])) if len(axis) == 0: return self