remove dtype from range, it will be dtypes.index soon [pr] (#11914)

* remove dtype from range, it will be dtypes.index soon [pr]

* a few more
This commit is contained in:
George Hotz
2025-08-29 09:52:07 -07:00
committed by GitHub
parent 30e72d5820
commit afad7d0cd1
11 changed files with 58 additions and 61 deletions

View File

@@ -65,7 +65,7 @@ def top_spec_kernel3():
c = a@b
sink = c.schedule()[-1].ast
L = 16
sink = sink.reshape((N//L, L, N//L, L)) #.lift({0:UOp.range(dtypes.int, N//BM, 0), 2:UOp.range(dtypes.int, N//BN, 1)})
sink = sink.reshape((N//L, L, N//L, L)) #.lift({0:UOp.range(N//BM, 0), 2:UOp.range(N//BN, 1)})
sink = graph_rewrite(sink, view_left+pm)
axis_types = (AxisType.GLOBAL, AxisType.LOCAL, AxisType.GLOBAL, AxisType.LOCAL, AxisType.REDUCE)
return sink.replace(arg=KernelInfo(name="top_"+to_colored(sink.full_shape, axis_types), axis_types=axis_types))
@@ -186,7 +186,7 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), arg=2)
i = UOp.range(dtypes.int, c_regs.dtype.size, 16)
i = UOp.range(c_regs.dtype.size, 16)
init_store = c_regs[i].store(UOp.const(dtypes.float, 0.0), i)
if kernel4:
@@ -197,53 +197,53 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
kId = 0
# load from globals into locals
i = UOp.range(dtypes.int, nbReadsB, 0)
i = UOp.range(nbReadsB, 0)
index_x = BN * blockIdx_x + rBIdx
index_y = rBIdy + i * strideReadB + kId
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
i = UOp.range(dtypes.int, nbReadsA, 1)
i = UOp.range(nbReadsA, 1)
index_x = rAIdx + kId
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i)
# iterate over the middle chunk
kId_range = UOp.range(dtypes.int, N//BK-1, 2)
kId_range = UOp.range(N//BK-1, 2)
kId = kId_range*BK
barrier = UOp.barrier(As_store, Bs_store)
# load from globals into registers (next round)
i = UOp.range(dtypes.int, nbReadsB, 3)
i = UOp.range(nbReadsB, 3)
index_x = BN * blockIdx_x + rBIdx
index_y = rBIdy + i * strideReadB + kId + BK
regB_store = regB[i].store(b[N * index_y + index_x].load(), i)
i = UOp.range(dtypes.int, nbReadsA, 4)
i = UOp.range(nbReadsA, 4)
index_x = rAIdx + kId + BK
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
regA_store = regA[i].store(a[N * index_y + index_x].load(), i)
def inner_loop(first_range, inp_dep=()):
# inner unroll
k = UOp.range(dtypes.int, BK, first_range+0)
k = UOp.range(BK, first_range+0)
# load from locals into registers
iterWave = UOp.range(dtypes.int, nbIterWaveN, first_range+1)
i = UOp.range(dtypes.int, TN, first_range+2)
iterWave = UOp.range(nbIterWaveN, first_range+1)
i = UOp.range(TN, first_range+2)
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(*inp_dep), iterWave, i)
iterWave = UOp.range(dtypes.int, nbIterWaveM, first_range+3)
i = UOp.range(dtypes.int, TM, first_range+4)
iterWave = UOp.range(nbIterWaveM, first_range+3)
i = UOp.range(TM, first_range+4)
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(*inp_dep), iterWave, i)
# do the GEMM math
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, first_range+5)
yt = UOp.range(dtypes.int, TM, first_range+6)
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, first_range+7)
xt = UOp.range(dtypes.int, TN, first_range+8)
iterWaveM = UOp.range(nbIterWaveM, first_range+5)
yt = UOp.range(TM, first_range+6)
iterWaveN = UOp.range(nbIterWaveN, first_range+7)
xt = UOp.range(TN, first_range+8)
x = iterWaveN * TN + xt
y = iterWaveM * TM + yt
c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
@@ -256,12 +256,12 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
sink = inner_loop(5, (barrier, regB_store, regA_store)).barrier()
# load from registers into locals
i = UOp.range(dtypes.int, nbReadsB, 14)
i = UOp.range(nbReadsB, 14)
index_x = BN * blockIdx_x + rBIdx
index_y = rBIdy + i * strideReadB + kId + BK
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(regB[i].load(sink), i, kId_range)
i = UOp.range(dtypes.int, nbReadsA, 15)
i = UOp.range(nbReadsA, 15)
index_x = rAIdx + kId + BK
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(regA[i].load(sink), i, kId_range)
@@ -269,40 +269,40 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
# final iteration without the copy
sink = inner_loop(16, (UOp.barrier(Bs_store, As_store),))
else:
kId_range = UOp.range(dtypes.int, N//BK, 0)
kId_range = UOp.range(N//BK, 0)
kId = kId_range*BK
# load from globals into locals
i = UOp.range(dtypes.int, nbReadsB, 1)
i = UOp.range(nbReadsB, 1)
index_x = BN * blockIdx_x + rBIdx
index_y = rBIdy + i * strideReadB + kId
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
i = UOp.range(dtypes.int, nbReadsA, 2)
i = UOp.range(nbReadsA, 2)
index_x = rAIdx + kId
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i)
barrier = UOp.barrier(As_store, Bs_store)
k = UOp.range(dtypes.int, BK, 3)
k = UOp.range(BK, 3)
# load from locals into registers
iterWave = UOp.range(dtypes.int, nbIterWaveN, 4)
i = UOp.range(dtypes.int, TN, 5)
iterWave = UOp.range(nbIterWaveN, 4)
i = UOp.range(TN, 5)
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(barrier), iterWave, i)
iterWave = UOp.range(dtypes.int, nbIterWaveM, 6)
i = UOp.range(dtypes.int, TM, 7)
iterWave = UOp.range(nbIterWaveM, 6)
i = UOp.range(TM, 7)
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(barrier), iterWave, i)
# do the GEMM math
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 8)
yt = UOp.range(dtypes.int, TM, 9)
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 10)
xt = UOp.range(dtypes.int, TN, 12)
iterWaveM = UOp.range(nbIterWaveM, 8)
yt = UOp.range(TM, 9)
iterWaveN = UOp.range(nbIterWaveN, 10)
xt = UOp.range(TN, 12)
x = iterWaveN * TN + xt
y = iterWaveM * TM + yt
c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
@@ -310,10 +310,10 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
iterWaveM, iterWaveN, yt, xt, k, kId_range)
# store c_regs into c
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 1000)
yt = UOp.range(dtypes.int, TM, 1001)
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 1002)
xt = UOp.range(dtypes.int, TN, 1003)
iterWaveM = UOp.range(nbIterWaveM, 1000)
yt = UOp.range(TM, 1001)
iterWaveN = UOp.range(nbIterWaveN, 1002)
xt = UOp.range(TN, 1003)
xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave
yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave
indexC = N * (yOut + yt) + xOut + xt

View File

@@ -1,5 +1,5 @@
import unittest
from tinygrad import Tensor, nn, Variable, UOp, dtypes
from tinygrad import Tensor, nn, Variable, UOp
# outerworld range should support three things
# 1. full optimizer steps (test_model_bound_range)
@@ -136,7 +136,7 @@ class TestOuterworldRange(unittest.TestCase):
def test_model_bound_range(self):
m, opt = get_model_and_opt()
# TODO: should ranges be unique so you don't have to pass in the -1?
rng = UOp.range(dtypes.int, self.STEPS, -1)
rng = UOp.range(self.STEPS, -1)
vib = Variable('i', 0, self.STEPS-1).bind(rng)
loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
loss.backward()

View File

@@ -1,6 +1,7 @@
import unittest
from tinygrad import Tensor
from tinygrad.helpers import RANGEIFY, Context, GlobalCounters
from tinygrad.uop.ops import UOp
N = 256
@@ -141,9 +142,6 @@ class TestRangeify(unittest.TestCase):
print(f"mse: {mse}")
self.assertLessEqual(mse, 1e-6)
from tinygrad import dtypes
from tinygrad.uop.ops import UOp
# contiguous + reduce can support ranges?
@unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY")
@@ -152,7 +150,7 @@ class TestOuterworld(unittest.TestCase):
t = Tensor.rand(10, 10).realize()
# passthrough ranges
a = UOp.range(dtypes.int, 10, -1)
a = UOp.range(10, -1)
sel = t[a]
cpy = sel.contiguous(a).realize()
@@ -162,7 +160,7 @@ class TestOuterworld(unittest.TestCase):
t = Tensor.rand(10, 10).realize()
# passthrough ranges
a = UOp.range(dtypes.int, 10, -1)
a = UOp.range(10, -1)
sel = t[9-a]
cpy = sel.contiguous(a).realize()
@@ -174,7 +172,7 @@ class TestOuterworld(unittest.TestCase):
x = Tensor.ones(3, 10, 2).contiguous()
# vmap across axis 0
a = UOp.range(dtypes.int, 3, -1)
a = UOp.range(3, -1)
out = f(x[a])
out = out.contiguous(a)
@@ -188,7 +186,7 @@ class TestOuterworld(unittest.TestCase):
manual = (x @ W[0] @ W[1] @ W[2]).contiguous().realize()
a = UOp.range(dtypes.int, 3, -1)
a = UOp.range(3, -1)
x = x.assign(x @ W[a])
out = x.contiguous(a)[-1].contiguous().realize()

View File

@@ -477,7 +477,7 @@ class TestUOpGraph(unittest.TestCase):
def test_load_with_float_in_index(self):
with Context(IGNORE_OOB=0):
ridx = UOp.range(dtypes.int, 20, 0)
ridx = UOp.range(20, 0)
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16))),))
@@ -490,7 +490,7 @@ class TestUOpGraph(unittest.TestCase):
def test_load_cast_to_bool(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
ridx = UOp.range(dtypes.int, 20, 0)
ridx = UOp.range(20, 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx, ridx.cast(dtypes.bool).logical_not()),))
to_uops_list([ld0])
@@ -499,7 +499,7 @@ class TestUOpGraph(unittest.TestCase):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
ridx = UOp.range(dtypes.int, 20, 0)
ridx = UOp.range(20, 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask),)))
to_uops_list([ld0])
@@ -592,8 +592,8 @@ class TestUOpGraph(unittest.TestCase):
def test_switched_range_order(self):
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
cf = UOp.const(dtypes.float, 0.0)
r1 = UOp.range(dtypes.int, 2, 0)
r2 = UOp.range(dtypes.int, 2, 1)
r1 = UOp.range(2, 0)
r2 = UOp.range(2, 1)
alu = UOp(Ops.MUL, dtypes.int, (r2, r1))
store = UOp(Ops.STORE, dtypes.void, (glbl.index(alu), cf))
uops = to_uops_list([store])

View File

@@ -403,7 +403,7 @@ class TestAssembly(unittest.TestCase):
self.assertNotIn(Ops.IDIV, ops)
def test_fast_idiv_remove_powers_of_two(self):
ridx = UOp.range(dtypes.int, 2**20, 0)
ridx = UOp.range(2**20, 0)
uops = to_uops_list([ridx//(7*64)], opts=Device[Device.DEFAULT].renderer)
ops = [x.op for x in uops]
# this requires shifting out the powers of two before doing fast_idiv

View File

@@ -65,21 +65,21 @@ class TestFoldingAndReduction(unittest.TestCase):
def test_full_graph_rewrite_reduction_with_unused_range(self):
const1 = UOp.const(dtypes.int32, 15)
const2 = UOp.const(dtypes.int32, 25)
rng = UOp.range(dtypes.int32, 10, idx=0)
rng = UOp.range(10, idx=0)
optimized_sink = apply_rewrite((const1 + const2).reduce(Ops.ADD, rng))
expected_sum = 10 * (15 + 25)
self.assertEqual(optimized_sink.arg, expected_sum)
@unittest.skip("currently failing")
def test_full_graph_rewrite_range_reduction(self):
simple_range = UOp.range(dtypes.int32, 5, idx=0)
simple_range = UOp.range(5, idx=0)
optimized_sink = apply_rewrite(simple_range.reduce(Ops.ADD, simple_range))
expected_sum = sum(range(5))
self.assertEqual(optimized_sink.arg, expected_sum)
@unittest.skip("currently failing")
def test_full_graph_rewrite_simple_reduction_folding(self):
simple_range = UOp.range(dtypes.int32, 4, idx=0)
simple_range = UOp.range(4, idx=0)
add_uop = simple_range + UOp.const(dtypes.int32, 1)
optimized_sink = apply_rewrite(add_uop.reduce(Ops.ADD, simple_range))
expected_sum = sum(i + 1 for i in range(4))
@@ -87,8 +87,8 @@ class TestFoldingAndReduction(unittest.TestCase):
@unittest.skip("currently failing")
def test_full_graph_rewrite_nested_loop_collapse(self):
outer_range = UOp.range(dtypes.int32, 8, 0)
inner_range = UOp.range(dtypes.int32, 4, 1)
outer_range = UOp.range(8, 0)
inner_range = UOp.range(4, 1)
expr = (outer_range * 10) + inner_range
optimized_reduce_uop = apply_rewrite(expr.reduce(Ops.ADD, outer_range, inner_range))
self.assertEqual(optimized_reduce_uop.op, Ops.CONST)

View File

@@ -19,7 +19,7 @@ def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UO
def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.int, (), (expr, nmax))
def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
def Range(n, nmax): return UOp.range(dtypes.int, nmax, n)
def Range(n, nmax): return UOp.range(nmax, n)
class TestHelpers(unittest.TestCase):
def test_is_increasing(self):

View File

@@ -1,6 +1,5 @@
# the job of the lowerer is to do indexing
from dataclasses import dataclass
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite, resolve
# ***** indexing *****
@@ -12,7 +11,7 @@ class IndexContext:
start: int = 0
def shape_to_idx(s, axis_types, start=0):
return [UOp.range(dtypes.int, sint_to_uop(s), start+i, at) for i, (s, at) in enumerate(zip(s, axis_types))]
return [UOp.range(sint_to_uop(s), start+i, at) for i, (s, at) in enumerate(zip(s, axis_types))]
def get_index(ast:UOp) -> IndexContext:
axis_types = ast.arg.axis_types if isinstance(ast.arg, KernelInfo) else ()

View File

@@ -109,7 +109,7 @@ class RangeifyContext:
# create ranges
range_idx: int = 0
def new_range(self, s:sint, axistype:AxisType=AxisType.LOOP):
ret = UOp.range(dtypes.int, s, self.range_idx, axistype)
ret = UOp.range(s, self.range_idx, axistype)
self.range_idx += 1
return ret

View File

@@ -114,7 +114,7 @@ class View:
def to_indexed_uops(self:View, idxs:Sequence[UOp]|None=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> tuple[UOp, UOp]:
"""(idx, valid)"""
if idxs is None: idxs = [UOp.range(dtypes.int, s, i) for i,s in enumerate(self.shape)]
if idxs is None: idxs = [UOp.range(s, i) for i,s in enumerate(self.shape)]
iexpr = sint_to_uop(self.offset)
for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else itertools.repeat(None)):
if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st

View File

@@ -297,10 +297,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
else: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
return ret
@staticmethod
def range(dtype:DType, end:sint, *arg):
def range(end:sint, *arg):
if len(arg) == 0: raise RuntimeError("range needs an arg")
if len(arg) == 1: arg = arg+(AxisType.LOOP,)
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=arg)
return UOp(Ops.RANGE, dtype=dtypes.int, src=(sint_to_uop(end),), arg=arg)
def r(self, op:Ops, axis:tuple[int, ...]):
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
if len(axis) == 0: return self