mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
remove dtype from range, it will be dtypes.index soon [pr] (#11914)
* remove dtype from range, it will be dtypes.index soon [pr] * a few more
This commit is contained in:
@@ -65,7 +65,7 @@ def top_spec_kernel3():
|
|||||||
c = a@b
|
c = a@b
|
||||||
sink = c.schedule()[-1].ast
|
sink = c.schedule()[-1].ast
|
||||||
L = 16
|
L = 16
|
||||||
sink = sink.reshape((N//L, L, N//L, L)) #.lift({0:UOp.range(dtypes.int, N//BM, 0), 2:UOp.range(dtypes.int, N//BN, 1)})
|
sink = sink.reshape((N//L, L, N//L, L)) #.lift({0:UOp.range(N//BM, 0), 2:UOp.range(N//BN, 1)})
|
||||||
sink = graph_rewrite(sink, view_left+pm)
|
sink = graph_rewrite(sink, view_left+pm)
|
||||||
axis_types = (AxisType.GLOBAL, AxisType.LOCAL, AxisType.GLOBAL, AxisType.LOCAL, AxisType.REDUCE)
|
axis_types = (AxisType.GLOBAL, AxisType.LOCAL, AxisType.GLOBAL, AxisType.LOCAL, AxisType.REDUCE)
|
||||||
return sink.replace(arg=KernelInfo(name="top_"+to_colored(sink.full_shape, axis_types), axis_types=axis_types))
|
return sink.replace(arg=KernelInfo(name="top_"+to_colored(sink.full_shape, axis_types), axis_types=axis_types))
|
||||||
@@ -186,7 +186,7 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
|
|||||||
|
|
||||||
c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), arg=2)
|
c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), arg=2)
|
||||||
|
|
||||||
i = UOp.range(dtypes.int, c_regs.dtype.size, 16)
|
i = UOp.range(c_regs.dtype.size, 16)
|
||||||
init_store = c_regs[i].store(UOp.const(dtypes.float, 0.0), i)
|
init_store = c_regs[i].store(UOp.const(dtypes.float, 0.0), i)
|
||||||
|
|
||||||
if kernel4:
|
if kernel4:
|
||||||
@@ -197,53 +197,53 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
|
|||||||
kId = 0
|
kId = 0
|
||||||
|
|
||||||
# load from globals into locals
|
# load from globals into locals
|
||||||
i = UOp.range(dtypes.int, nbReadsB, 0)
|
i = UOp.range(nbReadsB, 0)
|
||||||
index_x = BN * blockIdx_x + rBIdx
|
index_x = BN * blockIdx_x + rBIdx
|
||||||
index_y = rBIdy + i * strideReadB + kId
|
index_y = rBIdy + i * strideReadB + kId
|
||||||
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
|
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
|
||||||
|
|
||||||
i = UOp.range(dtypes.int, nbReadsA, 1)
|
i = UOp.range(nbReadsA, 1)
|
||||||
index_x = rAIdx + kId
|
index_x = rAIdx + kId
|
||||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||||
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i)
|
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i)
|
||||||
|
|
||||||
# iterate over the middle chunk
|
# iterate over the middle chunk
|
||||||
kId_range = UOp.range(dtypes.int, N//BK-1, 2)
|
kId_range = UOp.range(N//BK-1, 2)
|
||||||
kId = kId_range*BK
|
kId = kId_range*BK
|
||||||
|
|
||||||
barrier = UOp.barrier(As_store, Bs_store)
|
barrier = UOp.barrier(As_store, Bs_store)
|
||||||
|
|
||||||
# load from globals into registers (next round)
|
# load from globals into registers (next round)
|
||||||
i = UOp.range(dtypes.int, nbReadsB, 3)
|
i = UOp.range(nbReadsB, 3)
|
||||||
index_x = BN * blockIdx_x + rBIdx
|
index_x = BN * blockIdx_x + rBIdx
|
||||||
index_y = rBIdy + i * strideReadB + kId + BK
|
index_y = rBIdy + i * strideReadB + kId + BK
|
||||||
regB_store = regB[i].store(b[N * index_y + index_x].load(), i)
|
regB_store = regB[i].store(b[N * index_y + index_x].load(), i)
|
||||||
|
|
||||||
i = UOp.range(dtypes.int, nbReadsA, 4)
|
i = UOp.range(nbReadsA, 4)
|
||||||
index_x = rAIdx + kId + BK
|
index_x = rAIdx + kId + BK
|
||||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||||
regA_store = regA[i].store(a[N * index_y + index_x].load(), i)
|
regA_store = regA[i].store(a[N * index_y + index_x].load(), i)
|
||||||
|
|
||||||
def inner_loop(first_range, inp_dep=()):
|
def inner_loop(first_range, inp_dep=()):
|
||||||
# inner unroll
|
# inner unroll
|
||||||
k = UOp.range(dtypes.int, BK, first_range+0)
|
k = UOp.range(BK, first_range+0)
|
||||||
|
|
||||||
# load from locals into registers
|
# load from locals into registers
|
||||||
iterWave = UOp.range(dtypes.int, nbIterWaveN, first_range+1)
|
iterWave = UOp.range(nbIterWaveN, first_range+1)
|
||||||
i = UOp.range(dtypes.int, TN, first_range+2)
|
i = UOp.range(TN, first_range+2)
|
||||||
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
|
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
|
||||||
B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(*inp_dep), iterWave, i)
|
B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(*inp_dep), iterWave, i)
|
||||||
|
|
||||||
iterWave = UOp.range(dtypes.int, nbIterWaveM, first_range+3)
|
iterWave = UOp.range(nbIterWaveM, first_range+3)
|
||||||
i = UOp.range(dtypes.int, TM, first_range+4)
|
i = UOp.range(TM, first_range+4)
|
||||||
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
|
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
|
||||||
A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(*inp_dep), iterWave, i)
|
A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(*inp_dep), iterWave, i)
|
||||||
|
|
||||||
# do the GEMM math
|
# do the GEMM math
|
||||||
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, first_range+5)
|
iterWaveM = UOp.range(nbIterWaveM, first_range+5)
|
||||||
yt = UOp.range(dtypes.int, TM, first_range+6)
|
yt = UOp.range(TM, first_range+6)
|
||||||
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, first_range+7)
|
iterWaveN = UOp.range(nbIterWaveN, first_range+7)
|
||||||
xt = UOp.range(dtypes.int, TN, first_range+8)
|
xt = UOp.range(TN, first_range+8)
|
||||||
x = iterWaveN * TN + xt
|
x = iterWaveN * TN + xt
|
||||||
y = iterWaveM * TM + yt
|
y = iterWaveM * TM + yt
|
||||||
c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
|
c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
|
||||||
@@ -256,12 +256,12 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
|
|||||||
sink = inner_loop(5, (barrier, regB_store, regA_store)).barrier()
|
sink = inner_loop(5, (barrier, regB_store, regA_store)).barrier()
|
||||||
|
|
||||||
# load from registers into locals
|
# load from registers into locals
|
||||||
i = UOp.range(dtypes.int, nbReadsB, 14)
|
i = UOp.range(nbReadsB, 14)
|
||||||
index_x = BN * blockIdx_x + rBIdx
|
index_x = BN * blockIdx_x + rBIdx
|
||||||
index_y = rBIdy + i * strideReadB + kId + BK
|
index_y = rBIdy + i * strideReadB + kId + BK
|
||||||
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(regB[i].load(sink), i, kId_range)
|
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(regB[i].load(sink), i, kId_range)
|
||||||
|
|
||||||
i = UOp.range(dtypes.int, nbReadsA, 15)
|
i = UOp.range(nbReadsA, 15)
|
||||||
index_x = rAIdx + kId + BK
|
index_x = rAIdx + kId + BK
|
||||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||||
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(regA[i].load(sink), i, kId_range)
|
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(regA[i].load(sink), i, kId_range)
|
||||||
@@ -269,40 +269,40 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
|
|||||||
# final iteration without the copy
|
# final iteration without the copy
|
||||||
sink = inner_loop(16, (UOp.barrier(Bs_store, As_store),))
|
sink = inner_loop(16, (UOp.barrier(Bs_store, As_store),))
|
||||||
else:
|
else:
|
||||||
kId_range = UOp.range(dtypes.int, N//BK, 0)
|
kId_range = UOp.range(N//BK, 0)
|
||||||
kId = kId_range*BK
|
kId = kId_range*BK
|
||||||
|
|
||||||
# load from globals into locals
|
# load from globals into locals
|
||||||
i = UOp.range(dtypes.int, nbReadsB, 1)
|
i = UOp.range(nbReadsB, 1)
|
||||||
index_x = BN * blockIdx_x + rBIdx
|
index_x = BN * blockIdx_x + rBIdx
|
||||||
index_y = rBIdy + i * strideReadB + kId
|
index_y = rBIdy + i * strideReadB + kId
|
||||||
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
|
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
|
||||||
|
|
||||||
i = UOp.range(dtypes.int, nbReadsA, 2)
|
i = UOp.range(nbReadsA, 2)
|
||||||
index_x = rAIdx + kId
|
index_x = rAIdx + kId
|
||||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||||
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i)
|
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i)
|
||||||
|
|
||||||
barrier = UOp.barrier(As_store, Bs_store)
|
barrier = UOp.barrier(As_store, Bs_store)
|
||||||
|
|
||||||
k = UOp.range(dtypes.int, BK, 3)
|
k = UOp.range(BK, 3)
|
||||||
|
|
||||||
# load from locals into registers
|
# load from locals into registers
|
||||||
iterWave = UOp.range(dtypes.int, nbIterWaveN, 4)
|
iterWave = UOp.range(nbIterWaveN, 4)
|
||||||
i = UOp.range(dtypes.int, TN, 5)
|
i = UOp.range(TN, 5)
|
||||||
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
|
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
|
||||||
B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(barrier), iterWave, i)
|
B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(barrier), iterWave, i)
|
||||||
|
|
||||||
iterWave = UOp.range(dtypes.int, nbIterWaveM, 6)
|
iterWave = UOp.range(nbIterWaveM, 6)
|
||||||
i = UOp.range(dtypes.int, TM, 7)
|
i = UOp.range(TM, 7)
|
||||||
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
|
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
|
||||||
A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(barrier), iterWave, i)
|
A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(barrier), iterWave, i)
|
||||||
|
|
||||||
# do the GEMM math
|
# do the GEMM math
|
||||||
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 8)
|
iterWaveM = UOp.range(nbIterWaveM, 8)
|
||||||
yt = UOp.range(dtypes.int, TM, 9)
|
yt = UOp.range(TM, 9)
|
||||||
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 10)
|
iterWaveN = UOp.range(nbIterWaveN, 10)
|
||||||
xt = UOp.range(dtypes.int, TN, 12)
|
xt = UOp.range(TN, 12)
|
||||||
x = iterWaveN * TN + xt
|
x = iterWaveN * TN + xt
|
||||||
y = iterWaveM * TM + yt
|
y = iterWaveM * TM + yt
|
||||||
c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
|
c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
|
||||||
@@ -310,10 +310,10 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
|
|||||||
iterWaveM, iterWaveN, yt, xt, k, kId_range)
|
iterWaveM, iterWaveN, yt, xt, k, kId_range)
|
||||||
|
|
||||||
# store c_regs into c
|
# store c_regs into c
|
||||||
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 1000)
|
iterWaveM = UOp.range(nbIterWaveM, 1000)
|
||||||
yt = UOp.range(dtypes.int, TM, 1001)
|
yt = UOp.range(TM, 1001)
|
||||||
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 1002)
|
iterWaveN = UOp.range(nbIterWaveN, 1002)
|
||||||
xt = UOp.range(dtypes.int, TN, 1003)
|
xt = UOp.range(TN, 1003)
|
||||||
xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave
|
xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave
|
||||||
yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave
|
yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave
|
||||||
indexC = N * (yOut + yt) + xOut + xt
|
indexC = N * (yOut + yt) + xOut + xt
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import unittest
|
import unittest
|
||||||
from tinygrad import Tensor, nn, Variable, UOp, dtypes
|
from tinygrad import Tensor, nn, Variable, UOp
|
||||||
|
|
||||||
# outerworld range should support three things
|
# outerworld range should support three things
|
||||||
# 1. full optimizer steps (test_model_bound_range)
|
# 1. full optimizer steps (test_model_bound_range)
|
||||||
@@ -136,7 +136,7 @@ class TestOuterworldRange(unittest.TestCase):
|
|||||||
def test_model_bound_range(self):
|
def test_model_bound_range(self):
|
||||||
m, opt = get_model_and_opt()
|
m, opt = get_model_and_opt()
|
||||||
# TODO: should ranges be unique so you don't have to pass in the -1?
|
# TODO: should ranges be unique so you don't have to pass in the -1?
|
||||||
rng = UOp.range(dtypes.int, self.STEPS, -1)
|
rng = UOp.range(self.STEPS, -1)
|
||||||
vib = Variable('i', 0, self.STEPS-1).bind(rng)
|
vib = Variable('i', 0, self.STEPS-1).bind(rng)
|
||||||
loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
|
loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
from tinygrad import Tensor
|
from tinygrad import Tensor
|
||||||
from tinygrad.helpers import RANGEIFY, Context, GlobalCounters
|
from tinygrad.helpers import RANGEIFY, Context, GlobalCounters
|
||||||
|
from tinygrad.uop.ops import UOp
|
||||||
|
|
||||||
N = 256
|
N = 256
|
||||||
|
|
||||||
@@ -141,9 +142,6 @@ class TestRangeify(unittest.TestCase):
|
|||||||
print(f"mse: {mse}")
|
print(f"mse: {mse}")
|
||||||
self.assertLessEqual(mse, 1e-6)
|
self.assertLessEqual(mse, 1e-6)
|
||||||
|
|
||||||
from tinygrad import dtypes
|
|
||||||
from tinygrad.uop.ops import UOp
|
|
||||||
|
|
||||||
# contiguous + reduce can support ranges?
|
# contiguous + reduce can support ranges?
|
||||||
|
|
||||||
@unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY")
|
@unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY")
|
||||||
@@ -152,7 +150,7 @@ class TestOuterworld(unittest.TestCase):
|
|||||||
t = Tensor.rand(10, 10).realize()
|
t = Tensor.rand(10, 10).realize()
|
||||||
|
|
||||||
# passthrough ranges
|
# passthrough ranges
|
||||||
a = UOp.range(dtypes.int, 10, -1)
|
a = UOp.range(10, -1)
|
||||||
sel = t[a]
|
sel = t[a]
|
||||||
cpy = sel.contiguous(a).realize()
|
cpy = sel.contiguous(a).realize()
|
||||||
|
|
||||||
@@ -162,7 +160,7 @@ class TestOuterworld(unittest.TestCase):
|
|||||||
t = Tensor.rand(10, 10).realize()
|
t = Tensor.rand(10, 10).realize()
|
||||||
|
|
||||||
# passthrough ranges
|
# passthrough ranges
|
||||||
a = UOp.range(dtypes.int, 10, -1)
|
a = UOp.range(10, -1)
|
||||||
sel = t[9-a]
|
sel = t[9-a]
|
||||||
cpy = sel.contiguous(a).realize()
|
cpy = sel.contiguous(a).realize()
|
||||||
|
|
||||||
@@ -174,7 +172,7 @@ class TestOuterworld(unittest.TestCase):
|
|||||||
x = Tensor.ones(3, 10, 2).contiguous()
|
x = Tensor.ones(3, 10, 2).contiguous()
|
||||||
|
|
||||||
# vmap across axis 0
|
# vmap across axis 0
|
||||||
a = UOp.range(dtypes.int, 3, -1)
|
a = UOp.range(3, -1)
|
||||||
out = f(x[a])
|
out = f(x[a])
|
||||||
out = out.contiguous(a)
|
out = out.contiguous(a)
|
||||||
|
|
||||||
@@ -188,7 +186,7 @@ class TestOuterworld(unittest.TestCase):
|
|||||||
|
|
||||||
manual = (x @ W[0] @ W[1] @ W[2]).contiguous().realize()
|
manual = (x @ W[0] @ W[1] @ W[2]).contiguous().realize()
|
||||||
|
|
||||||
a = UOp.range(dtypes.int, 3, -1)
|
a = UOp.range(3, -1)
|
||||||
x = x.assign(x @ W[a])
|
x = x.assign(x @ W[a])
|
||||||
out = x.contiguous(a)[-1].contiguous().realize()
|
out = x.contiguous(a)[-1].contiguous().realize()
|
||||||
|
|
||||||
|
|||||||
@@ -477,7 +477,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||||||
|
|
||||||
def test_load_with_float_in_index(self):
|
def test_load_with_float_in_index(self):
|
||||||
with Context(IGNORE_OOB=0):
|
with Context(IGNORE_OOB=0):
|
||||||
ridx = UOp.range(dtypes.int, 20, 0)
|
ridx = UOp.range(20, 0)
|
||||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||||
i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int)
|
i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int)
|
||||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16))),))
|
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16))),))
|
||||||
@@ -490,7 +490,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||||||
def test_load_cast_to_bool(self):
|
def test_load_cast_to_bool(self):
|
||||||
with Context(IGNORE_OOB=0):
|
with Context(IGNORE_OOB=0):
|
||||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
|
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
|
||||||
ridx = UOp.range(dtypes.int, 20, 0)
|
ridx = UOp.range(20, 0)
|
||||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx, ridx.cast(dtypes.bool).logical_not()),))
|
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx, ridx.cast(dtypes.bool).logical_not()),))
|
||||||
to_uops_list([ld0])
|
to_uops_list([ld0])
|
||||||
|
|
||||||
@@ -499,7 +499,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||||||
with Context(IGNORE_OOB=0):
|
with Context(IGNORE_OOB=0):
|
||||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||||
mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
|
mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
|
||||||
ridx = UOp.range(dtypes.int, 20, 0)
|
ridx = UOp.range(20, 0)
|
||||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask),)))
|
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask),)))
|
||||||
to_uops_list([ld0])
|
to_uops_list([ld0])
|
||||||
|
|
||||||
@@ -592,8 +592,8 @@ class TestUOpGraph(unittest.TestCase):
|
|||||||
def test_switched_range_order(self):
|
def test_switched_range_order(self):
|
||||||
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||||
cf = UOp.const(dtypes.float, 0.0)
|
cf = UOp.const(dtypes.float, 0.0)
|
||||||
r1 = UOp.range(dtypes.int, 2, 0)
|
r1 = UOp.range(2, 0)
|
||||||
r2 = UOp.range(dtypes.int, 2, 1)
|
r2 = UOp.range(2, 1)
|
||||||
alu = UOp(Ops.MUL, dtypes.int, (r2, r1))
|
alu = UOp(Ops.MUL, dtypes.int, (r2, r1))
|
||||||
store = UOp(Ops.STORE, dtypes.void, (glbl.index(alu), cf))
|
store = UOp(Ops.STORE, dtypes.void, (glbl.index(alu), cf))
|
||||||
uops = to_uops_list([store])
|
uops = to_uops_list([store])
|
||||||
|
|||||||
@@ -403,7 +403,7 @@ class TestAssembly(unittest.TestCase):
|
|||||||
self.assertNotIn(Ops.IDIV, ops)
|
self.assertNotIn(Ops.IDIV, ops)
|
||||||
|
|
||||||
def test_fast_idiv_remove_powers_of_two(self):
|
def test_fast_idiv_remove_powers_of_two(self):
|
||||||
ridx = UOp.range(dtypes.int, 2**20, 0)
|
ridx = UOp.range(2**20, 0)
|
||||||
uops = to_uops_list([ridx//(7*64)], opts=Device[Device.DEFAULT].renderer)
|
uops = to_uops_list([ridx//(7*64)], opts=Device[Device.DEFAULT].renderer)
|
||||||
ops = [x.op for x in uops]
|
ops = [x.op for x in uops]
|
||||||
# this requires shifting out the powers of two before doing fast_idiv
|
# this requires shifting out the powers of two before doing fast_idiv
|
||||||
|
|||||||
@@ -65,21 +65,21 @@ class TestFoldingAndReduction(unittest.TestCase):
|
|||||||
def test_full_graph_rewrite_reduction_with_unused_range(self):
|
def test_full_graph_rewrite_reduction_with_unused_range(self):
|
||||||
const1 = UOp.const(dtypes.int32, 15)
|
const1 = UOp.const(dtypes.int32, 15)
|
||||||
const2 = UOp.const(dtypes.int32, 25)
|
const2 = UOp.const(dtypes.int32, 25)
|
||||||
rng = UOp.range(dtypes.int32, 10, idx=0)
|
rng = UOp.range(10, idx=0)
|
||||||
optimized_sink = apply_rewrite((const1 + const2).reduce(Ops.ADD, rng))
|
optimized_sink = apply_rewrite((const1 + const2).reduce(Ops.ADD, rng))
|
||||||
expected_sum = 10 * (15 + 25)
|
expected_sum = 10 * (15 + 25)
|
||||||
self.assertEqual(optimized_sink.arg, expected_sum)
|
self.assertEqual(optimized_sink.arg, expected_sum)
|
||||||
|
|
||||||
@unittest.skip("currently failing")
|
@unittest.skip("currently failing")
|
||||||
def test_full_graph_rewrite_range_reduction(self):
|
def test_full_graph_rewrite_range_reduction(self):
|
||||||
simple_range = UOp.range(dtypes.int32, 5, idx=0)
|
simple_range = UOp.range(5, idx=0)
|
||||||
optimized_sink = apply_rewrite(simple_range.reduce(Ops.ADD, simple_range))
|
optimized_sink = apply_rewrite(simple_range.reduce(Ops.ADD, simple_range))
|
||||||
expected_sum = sum(range(5))
|
expected_sum = sum(range(5))
|
||||||
self.assertEqual(optimized_sink.arg, expected_sum)
|
self.assertEqual(optimized_sink.arg, expected_sum)
|
||||||
|
|
||||||
@unittest.skip("currently failing")
|
@unittest.skip("currently failing")
|
||||||
def test_full_graph_rewrite_simple_reduction_folding(self):
|
def test_full_graph_rewrite_simple_reduction_folding(self):
|
||||||
simple_range = UOp.range(dtypes.int32, 4, idx=0)
|
simple_range = UOp.range(4, idx=0)
|
||||||
add_uop = simple_range + UOp.const(dtypes.int32, 1)
|
add_uop = simple_range + UOp.const(dtypes.int32, 1)
|
||||||
optimized_sink = apply_rewrite(add_uop.reduce(Ops.ADD, simple_range))
|
optimized_sink = apply_rewrite(add_uop.reduce(Ops.ADD, simple_range))
|
||||||
expected_sum = sum(i + 1 for i in range(4))
|
expected_sum = sum(i + 1 for i in range(4))
|
||||||
@@ -87,8 +87,8 @@ class TestFoldingAndReduction(unittest.TestCase):
|
|||||||
|
|
||||||
@unittest.skip("currently failing")
|
@unittest.skip("currently failing")
|
||||||
def test_full_graph_rewrite_nested_loop_collapse(self):
|
def test_full_graph_rewrite_nested_loop_collapse(self):
|
||||||
outer_range = UOp.range(dtypes.int32, 8, 0)
|
outer_range = UOp.range(8, 0)
|
||||||
inner_range = UOp.range(dtypes.int32, 4, 1)
|
inner_range = UOp.range(4, 1)
|
||||||
expr = (outer_range * 10) + inner_range
|
expr = (outer_range * 10) + inner_range
|
||||||
optimized_reduce_uop = apply_rewrite(expr.reduce(Ops.ADD, outer_range, inner_range))
|
optimized_reduce_uop = apply_rewrite(expr.reduce(Ops.ADD, outer_range, inner_range))
|
||||||
self.assertEqual(optimized_reduce_uop.op, Ops.CONST)
|
self.assertEqual(optimized_reduce_uop.op, Ops.CONST)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UO
|
|||||||
|
|
||||||
def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.int, (), (expr, nmax))
|
def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.int, (), (expr, nmax))
|
||||||
def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
|
def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
|
||||||
def Range(n, nmax): return UOp.range(dtypes.int, nmax, n)
|
def Range(n, nmax): return UOp.range(nmax, n)
|
||||||
|
|
||||||
class TestHelpers(unittest.TestCase):
|
class TestHelpers(unittest.TestCase):
|
||||||
def test_is_increasing(self):
|
def test_is_increasing(self):
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
# the job of the lowerer is to do indexing
|
# the job of the lowerer is to do indexing
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from tinygrad.dtype import dtypes
|
|
||||||
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite, resolve
|
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite, resolve
|
||||||
|
|
||||||
# ***** indexing *****
|
# ***** indexing *****
|
||||||
@@ -12,7 +11,7 @@ class IndexContext:
|
|||||||
start: int = 0
|
start: int = 0
|
||||||
|
|
||||||
def shape_to_idx(s, axis_types, start=0):
|
def shape_to_idx(s, axis_types, start=0):
|
||||||
return [UOp.range(dtypes.int, sint_to_uop(s), start+i, at) for i, (s, at) in enumerate(zip(s, axis_types))]
|
return [UOp.range(sint_to_uop(s), start+i, at) for i, (s, at) in enumerate(zip(s, axis_types))]
|
||||||
|
|
||||||
def get_index(ast:UOp) -> IndexContext:
|
def get_index(ast:UOp) -> IndexContext:
|
||||||
axis_types = ast.arg.axis_types if isinstance(ast.arg, KernelInfo) else ()
|
axis_types = ast.arg.axis_types if isinstance(ast.arg, KernelInfo) else ()
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ class RangeifyContext:
|
|||||||
# create ranges
|
# create ranges
|
||||||
range_idx: int = 0
|
range_idx: int = 0
|
||||||
def new_range(self, s:sint, axistype:AxisType=AxisType.LOOP):
|
def new_range(self, s:sint, axistype:AxisType=AxisType.LOOP):
|
||||||
ret = UOp.range(dtypes.int, s, self.range_idx, axistype)
|
ret = UOp.range(s, self.range_idx, axistype)
|
||||||
self.range_idx += 1
|
self.range_idx += 1
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|||||||
@@ -114,7 +114,7 @@ class View:
|
|||||||
|
|
||||||
def to_indexed_uops(self:View, idxs:Sequence[UOp]|None=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> tuple[UOp, UOp]:
|
def to_indexed_uops(self:View, idxs:Sequence[UOp]|None=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> tuple[UOp, UOp]:
|
||||||
"""(idx, valid)"""
|
"""(idx, valid)"""
|
||||||
if idxs is None: idxs = [UOp.range(dtypes.int, s, i) for i,s in enumerate(self.shape)]
|
if idxs is None: idxs = [UOp.range(s, i) for i,s in enumerate(self.shape)]
|
||||||
iexpr = sint_to_uop(self.offset)
|
iexpr = sint_to_uop(self.offset)
|
||||||
for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else itertools.repeat(None)):
|
for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else itertools.repeat(None)):
|
||||||
if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
|
if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
|
||||||
|
|||||||
@@ -297,10 +297,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||||||
else: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
|
else: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
|
||||||
return ret
|
return ret
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def range(dtype:DType, end:sint, *arg):
|
def range(end:sint, *arg):
|
||||||
if len(arg) == 0: raise RuntimeError("range needs an arg")
|
if len(arg) == 0: raise RuntimeError("range needs an arg")
|
||||||
if len(arg) == 1: arg = arg+(AxisType.LOOP,)
|
if len(arg) == 1: arg = arg+(AxisType.LOOP,)
|
||||||
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=arg)
|
return UOp(Ops.RANGE, dtype=dtypes.int, src=(sint_to_uop(end),), arg=arg)
|
||||||
def r(self, op:Ops, axis:tuple[int, ...]):
|
def r(self, op:Ops, axis:tuple[int, ...]):
|
||||||
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
|
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
|
||||||
if len(axis) == 0: return self
|
if len(axis) == 0: return self
|
||||||
|
|||||||
Reference in New Issue
Block a user