mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
remove dtype from range, it will be dtypes.index soon [pr] (#11914)
* remove dtype from range, it will be dtypes.index soon [pr] * a few more
This commit is contained in:
@@ -65,7 +65,7 @@ def top_spec_kernel3():
|
||||
c = a@b
|
||||
sink = c.schedule()[-1].ast
|
||||
L = 16
|
||||
sink = sink.reshape((N//L, L, N//L, L)) #.lift({0:UOp.range(dtypes.int, N//BM, 0), 2:UOp.range(dtypes.int, N//BN, 1)})
|
||||
sink = sink.reshape((N//L, L, N//L, L)) #.lift({0:UOp.range(N//BM, 0), 2:UOp.range(N//BN, 1)})
|
||||
sink = graph_rewrite(sink, view_left+pm)
|
||||
axis_types = (AxisType.GLOBAL, AxisType.LOCAL, AxisType.GLOBAL, AxisType.LOCAL, AxisType.REDUCE)
|
||||
return sink.replace(arg=KernelInfo(name="top_"+to_colored(sink.full_shape, axis_types), axis_types=axis_types))
|
||||
@@ -186,7 +186,7 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
|
||||
|
||||
c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), arg=2)
|
||||
|
||||
i = UOp.range(dtypes.int, c_regs.dtype.size, 16)
|
||||
i = UOp.range(c_regs.dtype.size, 16)
|
||||
init_store = c_regs[i].store(UOp.const(dtypes.float, 0.0), i)
|
||||
|
||||
if kernel4:
|
||||
@@ -197,53 +197,53 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
|
||||
kId = 0
|
||||
|
||||
# load from globals into locals
|
||||
i = UOp.range(dtypes.int, nbReadsB, 0)
|
||||
i = UOp.range(nbReadsB, 0)
|
||||
index_x = BN * blockIdx_x + rBIdx
|
||||
index_y = rBIdy + i * strideReadB + kId
|
||||
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
|
||||
|
||||
i = UOp.range(dtypes.int, nbReadsA, 1)
|
||||
i = UOp.range(nbReadsA, 1)
|
||||
index_x = rAIdx + kId
|
||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i)
|
||||
|
||||
# iterate over the middle chunk
|
||||
kId_range = UOp.range(dtypes.int, N//BK-1, 2)
|
||||
kId_range = UOp.range(N//BK-1, 2)
|
||||
kId = kId_range*BK
|
||||
|
||||
barrier = UOp.barrier(As_store, Bs_store)
|
||||
|
||||
# load from globals into registers (next round)
|
||||
i = UOp.range(dtypes.int, nbReadsB, 3)
|
||||
i = UOp.range(nbReadsB, 3)
|
||||
index_x = BN * blockIdx_x + rBIdx
|
||||
index_y = rBIdy + i * strideReadB + kId + BK
|
||||
regB_store = regB[i].store(b[N * index_y + index_x].load(), i)
|
||||
|
||||
i = UOp.range(dtypes.int, nbReadsA, 4)
|
||||
i = UOp.range(nbReadsA, 4)
|
||||
index_x = rAIdx + kId + BK
|
||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||
regA_store = regA[i].store(a[N * index_y + index_x].load(), i)
|
||||
|
||||
def inner_loop(first_range, inp_dep=()):
|
||||
# inner unroll
|
||||
k = UOp.range(dtypes.int, BK, first_range+0)
|
||||
k = UOp.range(BK, first_range+0)
|
||||
|
||||
# load from locals into registers
|
||||
iterWave = UOp.range(dtypes.int, nbIterWaveN, first_range+1)
|
||||
i = UOp.range(dtypes.int, TN, first_range+2)
|
||||
iterWave = UOp.range(nbIterWaveN, first_range+1)
|
||||
i = UOp.range(TN, first_range+2)
|
||||
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
|
||||
B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(*inp_dep), iterWave, i)
|
||||
|
||||
iterWave = UOp.range(dtypes.int, nbIterWaveM, first_range+3)
|
||||
i = UOp.range(dtypes.int, TM, first_range+4)
|
||||
iterWave = UOp.range(nbIterWaveM, first_range+3)
|
||||
i = UOp.range(TM, first_range+4)
|
||||
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
|
||||
A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(*inp_dep), iterWave, i)
|
||||
|
||||
# do the GEMM math
|
||||
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, first_range+5)
|
||||
yt = UOp.range(dtypes.int, TM, first_range+6)
|
||||
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, first_range+7)
|
||||
xt = UOp.range(dtypes.int, TN, first_range+8)
|
||||
iterWaveM = UOp.range(nbIterWaveM, first_range+5)
|
||||
yt = UOp.range(TM, first_range+6)
|
||||
iterWaveN = UOp.range(nbIterWaveN, first_range+7)
|
||||
xt = UOp.range(TN, first_range+8)
|
||||
x = iterWaveN * TN + xt
|
||||
y = iterWaveM * TM + yt
|
||||
c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
|
||||
@@ -256,12 +256,12 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
|
||||
sink = inner_loop(5, (barrier, regB_store, regA_store)).barrier()
|
||||
|
||||
# load from registers into locals
|
||||
i = UOp.range(dtypes.int, nbReadsB, 14)
|
||||
i = UOp.range(nbReadsB, 14)
|
||||
index_x = BN * blockIdx_x + rBIdx
|
||||
index_y = rBIdy + i * strideReadB + kId + BK
|
||||
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(regB[i].load(sink), i, kId_range)
|
||||
|
||||
i = UOp.range(dtypes.int, nbReadsA, 15)
|
||||
i = UOp.range(nbReadsA, 15)
|
||||
index_x = rAIdx + kId + BK
|
||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(regA[i].load(sink), i, kId_range)
|
||||
@@ -269,40 +269,40 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
|
||||
# final iteration without the copy
|
||||
sink = inner_loop(16, (UOp.barrier(Bs_store, As_store),))
|
||||
else:
|
||||
kId_range = UOp.range(dtypes.int, N//BK, 0)
|
||||
kId_range = UOp.range(N//BK, 0)
|
||||
kId = kId_range*BK
|
||||
|
||||
# load from globals into locals
|
||||
i = UOp.range(dtypes.int, nbReadsB, 1)
|
||||
i = UOp.range(nbReadsB, 1)
|
||||
index_x = BN * blockIdx_x + rBIdx
|
||||
index_y = rBIdy + i * strideReadB + kId
|
||||
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
|
||||
|
||||
i = UOp.range(dtypes.int, nbReadsA, 2)
|
||||
i = UOp.range(nbReadsA, 2)
|
||||
index_x = rAIdx + kId
|
||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i)
|
||||
|
||||
barrier = UOp.barrier(As_store, Bs_store)
|
||||
|
||||
k = UOp.range(dtypes.int, BK, 3)
|
||||
k = UOp.range(BK, 3)
|
||||
|
||||
# load from locals into registers
|
||||
iterWave = UOp.range(dtypes.int, nbIterWaveN, 4)
|
||||
i = UOp.range(dtypes.int, TN, 5)
|
||||
iterWave = UOp.range(nbIterWaveN, 4)
|
||||
i = UOp.range(TN, 5)
|
||||
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
|
||||
B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(barrier), iterWave, i)
|
||||
|
||||
iterWave = UOp.range(dtypes.int, nbIterWaveM, 6)
|
||||
i = UOp.range(dtypes.int, TM, 7)
|
||||
iterWave = UOp.range(nbIterWaveM, 6)
|
||||
i = UOp.range(TM, 7)
|
||||
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
|
||||
A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(barrier), iterWave, i)
|
||||
|
||||
# do the GEMM math
|
||||
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 8)
|
||||
yt = UOp.range(dtypes.int, TM, 9)
|
||||
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 10)
|
||||
xt = UOp.range(dtypes.int, TN, 12)
|
||||
iterWaveM = UOp.range(nbIterWaveM, 8)
|
||||
yt = UOp.range(TM, 9)
|
||||
iterWaveN = UOp.range(nbIterWaveN, 10)
|
||||
xt = UOp.range(TN, 12)
|
||||
x = iterWaveN * TN + xt
|
||||
y = iterWaveM * TM + yt
|
||||
c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
|
||||
@@ -310,10 +310,10 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
|
||||
iterWaveM, iterWaveN, yt, xt, k, kId_range)
|
||||
|
||||
# store c_regs into c
|
||||
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 1000)
|
||||
yt = UOp.range(dtypes.int, TM, 1001)
|
||||
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 1002)
|
||||
xt = UOp.range(dtypes.int, TN, 1003)
|
||||
iterWaveM = UOp.range(nbIterWaveM, 1000)
|
||||
yt = UOp.range(TM, 1001)
|
||||
iterWaveN = UOp.range(nbIterWaveN, 1002)
|
||||
xt = UOp.range(TN, 1003)
|
||||
xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave
|
||||
yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave
|
||||
indexC = N * (yOut + yt) + xOut + xt
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, nn, Variable, UOp, dtypes
|
||||
from tinygrad import Tensor, nn, Variable, UOp
|
||||
|
||||
# outerworld range should support three things
|
||||
# 1. full optimizer steps (test_model_bound_range)
|
||||
@@ -136,7 +136,7 @@ class TestOuterworldRange(unittest.TestCase):
|
||||
def test_model_bound_range(self):
|
||||
m, opt = get_model_and_opt()
|
||||
# TODO: should ranges be unique so you don't have to pass in the -1?
|
||||
rng = UOp.range(dtypes.int, self.STEPS, -1)
|
||||
rng = UOp.range(self.STEPS, -1)
|
||||
vib = Variable('i', 0, self.STEPS-1).bind(rng)
|
||||
loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
|
||||
loss.backward()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.helpers import RANGEIFY, Context, GlobalCounters
|
||||
from tinygrad.uop.ops import UOp
|
||||
|
||||
N = 256
|
||||
|
||||
@@ -141,9 +142,6 @@ class TestRangeify(unittest.TestCase):
|
||||
print(f"mse: {mse}")
|
||||
self.assertLessEqual(mse, 1e-6)
|
||||
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.uop.ops import UOp
|
||||
|
||||
# contiguous + reduce can support ranges?
|
||||
|
||||
@unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY")
|
||||
@@ -152,7 +150,7 @@ class TestOuterworld(unittest.TestCase):
|
||||
t = Tensor.rand(10, 10).realize()
|
||||
|
||||
# passthrough ranges
|
||||
a = UOp.range(dtypes.int, 10, -1)
|
||||
a = UOp.range(10, -1)
|
||||
sel = t[a]
|
||||
cpy = sel.contiguous(a).realize()
|
||||
|
||||
@@ -162,7 +160,7 @@ class TestOuterworld(unittest.TestCase):
|
||||
t = Tensor.rand(10, 10).realize()
|
||||
|
||||
# passthrough ranges
|
||||
a = UOp.range(dtypes.int, 10, -1)
|
||||
a = UOp.range(10, -1)
|
||||
sel = t[9-a]
|
||||
cpy = sel.contiguous(a).realize()
|
||||
|
||||
@@ -174,7 +172,7 @@ class TestOuterworld(unittest.TestCase):
|
||||
x = Tensor.ones(3, 10, 2).contiguous()
|
||||
|
||||
# vmap across axis 0
|
||||
a = UOp.range(dtypes.int, 3, -1)
|
||||
a = UOp.range(3, -1)
|
||||
out = f(x[a])
|
||||
out = out.contiguous(a)
|
||||
|
||||
@@ -188,7 +186,7 @@ class TestOuterworld(unittest.TestCase):
|
||||
|
||||
manual = (x @ W[0] @ W[1] @ W[2]).contiguous().realize()
|
||||
|
||||
a = UOp.range(dtypes.int, 3, -1)
|
||||
a = UOp.range(3, -1)
|
||||
x = x.assign(x @ W[a])
|
||||
out = x.contiguous(a)[-1].contiguous().realize()
|
||||
|
||||
|
||||
@@ -477,7 +477,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
|
||||
def test_load_with_float_in_index(self):
|
||||
with Context(IGNORE_OOB=0):
|
||||
ridx = UOp.range(dtypes.int, 20, 0)
|
||||
ridx = UOp.range(20, 0)
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||
i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16))),))
|
||||
@@ -490,7 +490,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
def test_load_cast_to_bool(self):
|
||||
with Context(IGNORE_OOB=0):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
|
||||
ridx = UOp.range(dtypes.int, 20, 0)
|
||||
ridx = UOp.range(20, 0)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx, ridx.cast(dtypes.bool).logical_not()),))
|
||||
to_uops_list([ld0])
|
||||
|
||||
@@ -499,7 +499,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
with Context(IGNORE_OOB=0):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||
mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
|
||||
ridx = UOp.range(dtypes.int, 20, 0)
|
||||
ridx = UOp.range(20, 0)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask),)))
|
||||
to_uops_list([ld0])
|
||||
|
||||
@@ -592,8 +592,8 @@ class TestUOpGraph(unittest.TestCase):
|
||||
def test_switched_range_order(self):
|
||||
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
cf = UOp.const(dtypes.float, 0.0)
|
||||
r1 = UOp.range(dtypes.int, 2, 0)
|
||||
r2 = UOp.range(dtypes.int, 2, 1)
|
||||
r1 = UOp.range(2, 0)
|
||||
r2 = UOp.range(2, 1)
|
||||
alu = UOp(Ops.MUL, dtypes.int, (r2, r1))
|
||||
store = UOp(Ops.STORE, dtypes.void, (glbl.index(alu), cf))
|
||||
uops = to_uops_list([store])
|
||||
|
||||
@@ -403,7 +403,7 @@ class TestAssembly(unittest.TestCase):
|
||||
self.assertNotIn(Ops.IDIV, ops)
|
||||
|
||||
def test_fast_idiv_remove_powers_of_two(self):
|
||||
ridx = UOp.range(dtypes.int, 2**20, 0)
|
||||
ridx = UOp.range(2**20, 0)
|
||||
uops = to_uops_list([ridx//(7*64)], opts=Device[Device.DEFAULT].renderer)
|
||||
ops = [x.op for x in uops]
|
||||
# this requires shifting out the powers of two before doing fast_idiv
|
||||
|
||||
@@ -65,21 +65,21 @@ class TestFoldingAndReduction(unittest.TestCase):
|
||||
def test_full_graph_rewrite_reduction_with_unused_range(self):
|
||||
const1 = UOp.const(dtypes.int32, 15)
|
||||
const2 = UOp.const(dtypes.int32, 25)
|
||||
rng = UOp.range(dtypes.int32, 10, idx=0)
|
||||
rng = UOp.range(10, idx=0)
|
||||
optimized_sink = apply_rewrite((const1 + const2).reduce(Ops.ADD, rng))
|
||||
expected_sum = 10 * (15 + 25)
|
||||
self.assertEqual(optimized_sink.arg, expected_sum)
|
||||
|
||||
@unittest.skip("currently failing")
|
||||
def test_full_graph_rewrite_range_reduction(self):
|
||||
simple_range = UOp.range(dtypes.int32, 5, idx=0)
|
||||
simple_range = UOp.range(5, idx=0)
|
||||
optimized_sink = apply_rewrite(simple_range.reduce(Ops.ADD, simple_range))
|
||||
expected_sum = sum(range(5))
|
||||
self.assertEqual(optimized_sink.arg, expected_sum)
|
||||
|
||||
@unittest.skip("currently failing")
|
||||
def test_full_graph_rewrite_simple_reduction_folding(self):
|
||||
simple_range = UOp.range(dtypes.int32, 4, idx=0)
|
||||
simple_range = UOp.range(4, idx=0)
|
||||
add_uop = simple_range + UOp.const(dtypes.int32, 1)
|
||||
optimized_sink = apply_rewrite(add_uop.reduce(Ops.ADD, simple_range))
|
||||
expected_sum = sum(i + 1 for i in range(4))
|
||||
@@ -87,8 +87,8 @@ class TestFoldingAndReduction(unittest.TestCase):
|
||||
|
||||
@unittest.skip("currently failing")
|
||||
def test_full_graph_rewrite_nested_loop_collapse(self):
|
||||
outer_range = UOp.range(dtypes.int32, 8, 0)
|
||||
inner_range = UOp.range(dtypes.int32, 4, 1)
|
||||
outer_range = UOp.range(8, 0)
|
||||
inner_range = UOp.range(4, 1)
|
||||
expr = (outer_range * 10) + inner_range
|
||||
optimized_reduce_uop = apply_rewrite(expr.reduce(Ops.ADD, outer_range, inner_range))
|
||||
self.assertEqual(optimized_reduce_uop.op, Ops.CONST)
|
||||
|
||||
@@ -19,7 +19,7 @@ def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UO
|
||||
|
||||
def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.int, (), (expr, nmax))
|
||||
def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
|
||||
def Range(n, nmax): return UOp.range(dtypes.int, nmax, n)
|
||||
def Range(n, nmax): return UOp.range(nmax, n)
|
||||
|
||||
class TestHelpers(unittest.TestCase):
|
||||
def test_is_increasing(self):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# the job of the lowerer is to do indexing
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite, resolve
|
||||
|
||||
# ***** indexing *****
|
||||
@@ -12,7 +11,7 @@ class IndexContext:
|
||||
start: int = 0
|
||||
|
||||
def shape_to_idx(s, axis_types, start=0):
|
||||
return [UOp.range(dtypes.int, sint_to_uop(s), start+i, at) for i, (s, at) in enumerate(zip(s, axis_types))]
|
||||
return [UOp.range(sint_to_uop(s), start+i, at) for i, (s, at) in enumerate(zip(s, axis_types))]
|
||||
|
||||
def get_index(ast:UOp) -> IndexContext:
|
||||
axis_types = ast.arg.axis_types if isinstance(ast.arg, KernelInfo) else ()
|
||||
|
||||
@@ -109,7 +109,7 @@ class RangeifyContext:
|
||||
# create ranges
|
||||
range_idx: int = 0
|
||||
def new_range(self, s:sint, axistype:AxisType=AxisType.LOOP):
|
||||
ret = UOp.range(dtypes.int, s, self.range_idx, axistype)
|
||||
ret = UOp.range(s, self.range_idx, axistype)
|
||||
self.range_idx += 1
|
||||
return ret
|
||||
|
||||
|
||||
@@ -114,7 +114,7 @@ class View:
|
||||
|
||||
def to_indexed_uops(self:View, idxs:Sequence[UOp]|None=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> tuple[UOp, UOp]:
|
||||
"""(idx, valid)"""
|
||||
if idxs is None: idxs = [UOp.range(dtypes.int, s, i) for i,s in enumerate(self.shape)]
|
||||
if idxs is None: idxs = [UOp.range(s, i) for i,s in enumerate(self.shape)]
|
||||
iexpr = sint_to_uop(self.offset)
|
||||
for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else itertools.repeat(None)):
|
||||
if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
|
||||
|
||||
@@ -297,10 +297,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
else: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
|
||||
return ret
|
||||
@staticmethod
|
||||
def range(dtype:DType, end:sint, *arg):
|
||||
def range(end:sint, *arg):
|
||||
if len(arg) == 0: raise RuntimeError("range needs an arg")
|
||||
if len(arg) == 1: arg = arg+(AxisType.LOOP,)
|
||||
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=arg)
|
||||
return UOp(Ops.RANGE, dtype=dtypes.int, src=(sint_to_uop(end),), arg=arg)
|
||||
def r(self, op:Ops, axis:tuple[int, ...]):
|
||||
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
|
||||
if len(axis) == 0: return self
|
||||
|
||||
Reference in New Issue
Block a user