mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-29 03:00:14 -04:00
dtypes.index -> dtypes.weakint (#15377)
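This commit mechanically renames the symbolic-index dtype dtypes.index to dtypes.weakint across the codebase; every hunk below is an instance of that substitution. A minimal before/after sketch of the new spelling (the import paths are assumptions, not shown in this diff):

from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp

# before: idx = UOp.const(dtypes.index, 0)
idx = UOp.const(dtypes.weakint, 0)   # after this commit
assert idx.dtype == dtypes.weakint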
@@ -2666,10 +2666,10 @@ def custom_uop_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
 m = UOp.range(M, 1, AxisType.LOOP)
 n = UOp.range(N, 2, AxisType.LOOP)
 k = UOp.range(K, 0, AxisType.REDUCE)
-mul = (A.flatten().index((m*UOp.const(dtypes.index, K)+k))*
-B.flatten().index((k*UOp.const(dtypes.index, N)+n))).cast(dtypes.float32)
+mul = (A.flatten().index((m*UOp.const(dtypes.weakint, K)+k))*
+B.flatten().index((k*UOp.const(dtypes.weakint, N)+n))).cast(dtypes.float32)
 red = mul.reduce(k, arg=Ops.ADD, dtype=dtypes.float32).cast(C.dtype.base)
-store = C.flatten().index((m*UOp.const(dtypes.index, N)+n), ptr=True).store(red).end(m, n)
+store = C.flatten().index((m*UOp.const(dtypes.weakint, N)+n), ptr=True).store(red).end(m, n)
 return store.sink(arg=KernelInfo(name=f'uop_gemm_{M}_{N}_{K}'))

 # ** backward gemm, might use the asm gemm
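As a usage note for the hunk above: flat-index arithmetic built from ranges stays entirely in the renamed dtype. A hedged sketch of the same row-major indexing pattern (assumed post-rename API; UOp.range defaults to dtypes.weakint per the ops.py hunk later in this diff):

from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp

M, N, K = 4, 4, 8
m = UOp.range(M, 1)  # loop range, dtype defaults to dtypes.weakint
k = UOp.range(K, 0)
flat = m*UOp.const(dtypes.weakint, K) + k  # A.flatten() offset for A[m, k]
assert flat.dtype == dtypes.weakint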
@@ -12,18 +12,18 @@ class TestLinearizerFailure(unittest.TestCase):
 @unittest.skipUnless(Device.DEFAULT == "METAL", "only tested on METAL")
 def test_failure_beam_mnist(self):
 c0 = UOp(Ops.PARAM, dtypes.uchar.ptr(4014080), arg=0, src=())
-c1 = UOp.range(UOp.const(dtypes.index, 512), 0, AxisType.GLOBAL)
-c2 = UOp.range(UOp.const(dtypes.index, 784), 1, AxisType.GLOBAL)
-c3 = UOp.range(UOp.const(dtypes.index, 10), 3, AxisType.GLOBAL)
+c1 = UOp.range(UOp.const(dtypes.weakint, 512), 0, AxisType.GLOBAL)
+c2 = UOp.range(UOp.const(dtypes.weakint, 784), 1, AxisType.GLOBAL)
+c3 = UOp.range(UOp.const(dtypes.weakint, 10), 3, AxisType.GLOBAL)
 c4 = UOp(Ops.PARAM, dtypes.int.ptr(512), arg=1, src=())
 c5 = c4.index(c1.valid(UOp.const(dtypes.bool, True)))
-c6 = UOp.range(UOp.const(dtypes.index, 6000), 1004, AxisType.REDUCE)
-c7 = UOp.range(UOp.const(dtypes.index, 3750), 2006, AxisType.REDUCE)
-c8 = UOp.range(UOp.const(dtypes.index, 16), 2007, AxisType.GROUP_REDUCE)
+c6 = UOp.range(UOp.const(dtypes.weakint, 6000), 1004, AxisType.REDUCE)
+c7 = UOp.range(UOp.const(dtypes.weakint, 3750), 2006, AxisType.REDUCE)
+c8 = UOp.range(UOp.const(dtypes.weakint, 16), 2007, AxisType.GROUP_REDUCE)
 c9 = UOp(Ops.PARAM, dtypes.uchar.ptr(47040000), arg=2, src=())
-c10 = c9.index((((c3*UOp.const(dtypes.index, 4704000))+c2)+(c6*UOp.const(dtypes.index, 784))).valid(UOp.const(dtypes.bool, True)))
-c11 = c5.alu(Ops.CMPNE, ((((c3*UOp.const(dtypes.index, 6000))+c6)+((c7*UOp.const(dtypes.index, 16))+c8)).alu(Ops.CMPLT, UOp.const(dtypes.index, 59999)).where(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)).reduce(c7, c8, arg=Ops.ADD)+UOp.const(dtypes.int, -1))).where(UOp.const(dtypes.uchar, 0), c10).reduce(c6, arg=Ops.ADD)
-c12 = c0.index((((c1*UOp.const(dtypes.index, 7840))+(c2*UOp.const(dtypes.index, 10)))+c3).valid(UOp.const(dtypes.bool, True))).store(c11).end(c1, c2, c3)
+c10 = c9.index((((c3*UOp.const(dtypes.weakint, 4704000))+c2)+(c6*UOp.const(dtypes.weakint, 784))).valid(UOp.const(dtypes.bool, True)))
+c11 = c5.alu(Ops.CMPNE, ((((c3*UOp.const(dtypes.weakint, 6000))+c6)+((c7*UOp.const(dtypes.weakint, 16))+c8)).alu(Ops.CMPLT, UOp.const(dtypes.weakint, 59999)).where(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)).reduce(c7, c8, arg=Ops.ADD)+UOp.const(dtypes.int, -1))).where(UOp.const(dtypes.uchar, 0), c10).reduce(c6, arg=Ops.ADD)
+c12 = c0.index((((c1*UOp.const(dtypes.weakint, 7840))+(c2*UOp.const(dtypes.weakint, 10)))+c3).valid(UOp.const(dtypes.bool, True))).store(c11).end(c1, c2, c3)
 ast = c12.sink(arg=KernelInfo(name='test', axis_types=(), dont_use_locals=False, applied_opts=(Opt(op=OptOps.GROUP, axis=1, arg=16),), opts_to_apply=None))
 _ = get_program(ast, Device["METAL"].renderer)

test/external/external_benchmark_op_conv.py

@@ -23,7 +23,7 @@ def vision_conv_143():
 c32 = ((c27<3)!=True)&(c27<67)
 c34 = UOp(Ops.PARAM, dtypes.imageh((32, 1024, 4)), (), 1)
 c38 = c5//2
-c45 = (c32&c24).where((c27*64+c38+c17*4096+-12480), UOp.const(dtypes.index, Invalid))
+c45 = (c32&c24).where((c27*64+c38+c17*4096+-12480), UOp.const(dtypes.weakint, Invalid))
 c48 = (c24&c32).where(c34.index(c45), UOp.const(dtypes.float, 0.0))
 c49 = UOp(Ops.PARAM, dtypes.imageh((64, 49, 4)), (), 2)
 c61 = c48*c49.index((c26*4+c5%2+c16*28+c38*196))
@@ -49,7 +49,7 @@ def vision_conv_153():
 c32 = ((c27<3)!=True)&(c27<35)
 c34 = UOp(Ops.PARAM, dtypes.imageh((16, 1024, 4)), (), 1)
 c38 = c5//2
-c45 = (c32&c24).where((c27*128+c38+c17*4096+-12672), UOp.const(dtypes.index, Invalid))
+c45 = (c32&c24).where((c27*128+c38+c17*4096+-12672), UOp.const(dtypes.weakint, Invalid))
 c48 = (c24&c32).where(c34.index(c45), UOp.const(dtypes.float, 0.0))
 c49 = UOp(Ops.PARAM, dtypes.imageh((128, 49, 4)), (), 2)
 c61 = c48*c49.index((c26*4+c5%2+c16*28+c38*196))

@@ -24,7 +24,7 @@ class TestGroupedDims(unittest.TestCase):
 total = math.prod(dims)
 specials = sorted(dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs])), key=lambda u: u.arg)
 # build flat index and primed flat (same expression with renamed SPECIALs)
-flat = UOp.const(dtypes.index, 0)
+flat = UOp.const(dtypes.weakint, 0)
 for i, idx in enumerate(idxs):
 flat = flat + idx * int(math.prod(dims[i+1:]))
 flat_p = flat.substitute({s: UOp(Ops.SPECIAL, s.dtype, s.src, s.arg+"_p") for s in specials})

@@ -99,21 +99,21 @@ class TestFoldingAndReduction(unittest.TestCase):
 class TestModuloAndDivisionFolding(unittest.TestCase):
 def test_full_graph_rewrite_modulo_folding_with_define_var(self):
 # index dtype because div-mod rules only work on index
-x_var_uop = UOp.variable('x', 0, 100).cast(dtypes.index)
+x_var_uop = UOp.variable('x', 0, 100).cast(dtypes.weakint)
 optimized_mod_uop = apply_rewrite(((x_var_uop * 4) + 2) % 4)
 self.assertEqual(optimized_mod_uop.op, Ops.CONST)
 self.assertEqual(optimized_mod_uop.arg, 2)

 def test_full_graph_rewrite_division_folding_with_define_var(self):
 # index dtype because div-mod rules only work on index
-n_var_uop = UOp.variable('n', 1, 1000).cast(dtypes.index)
+n_var_uop = UOp.variable('n', 1, 1000).cast(dtypes.weakint)
 optimized_div_uop = apply_rewrite((n_var_uop * 6) // 3)
 self.assertEqual(optimized_div_uop.op, Ops.MUL)
 self.assertEqual(optimized_div_uop.src[1].arg, 2)

 def test_full_graph_rewrite_complex_mod_div_folding(self):
 # index dtype because div-mod rules only work on index
-k_var_uop = UOp.variable('k', 0, 50).cast(dtypes.index)
+k_var_uop = UOp.variable('k', 0, 50).cast(dtypes.weakint)
 optimized_div_uop = apply_rewrite(((k_var_uop * 12 + 8) % 6) // 2)
 self.assertEqual(optimized_div_uop.op, Ops.CONST)
 self.assertEqual(optimized_div_uop.arg, 1)
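The three folds these tests assert are plain integer identities; a concrete check over the same variable ranges, in ordinary Python (no tinygrad involved):

for x in range(0, 101): assert ((x*4) + 2) % 4 == 2     # x*4 is 0 mod 4
for n in range(1, 1001): assert (n*6)//3 == n*2         # exact division folds to a MUL by 2
for k in range(0, 51): assert ((k*12 + 8) % 6)//2 == 1  # 12k is 0 mod 6, 8 % 6 = 2, 2//2 = 1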
@@ -132,7 +132,7 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
 def test_full_graph_rewrite_modulo_large_divisor(self):
 # index dtype because div-mod rules only work on index
 x_var_uop = UOp.variable('x', 1, 5)
-self.assertIs(apply_rewrite(x_var_uop.cast(dtypes.index) % 10).render(simplify=False), x_var_uop.render(simplify=False))
+self.assertIs(apply_rewrite(x_var_uop.cast(dtypes.weakint) % 10).render(simplify=False), x_var_uop.render(simplify=False))

 def test_full_graph_rewrite_division_with_remainder(self):
 x_var_uop = UOp.variable('x', 7, 9)

@@ -8,12 +8,12 @@ from tinygrad.device import Device
 class TestLinearizerFailures(unittest.TestCase):
 def test_fail_1(self):
 c0 = UOp(Ops.PARAM, dtypes.float.ptr(64), arg=0, src=())
-c1 = UOp.range(UOp.const(dtypes.index, 2), 1, AxisType.LOOP)
-c2 = UOp.range(UOp.const(dtypes.index, 32), 2, AxisType.LOOP)
-c3 = ((c1*UOp.const(dtypes.index, 32))+c2)
+c1 = UOp.range(UOp.const(dtypes.weakint, 2), 1, AxisType.LOOP)
+c2 = UOp.range(UOp.const(dtypes.weakint, 32), 2, AxisType.LOOP)
+c3 = ((c1*UOp.const(dtypes.weakint, 32))+c2)
 c4 = UOp(Ops.PARAM, dtypes.float.ptr(163840), arg=1, src=())
-c5 = UOp.range(UOp.const(dtypes.index, 2560), 0, AxisType.REDUCE)
-c6 = c4.index(((((((c5//UOp.const(dtypes.index, 8))%UOp.const(dtypes.index, 8))*UOp.const(dtypes.index, 8))+(c5%UOp.const(dtypes.index, 8)))+(((c2*UOp.const(dtypes.index, 40))+(c5//UOp.const(dtypes.index, 64)))*UOp.const(dtypes.index, 64)))+(c1*UOp.const(dtypes.index, 81920))))
+c5 = UOp.range(UOp.const(dtypes.weakint, 2560), 0, AxisType.REDUCE)
+c6 = c4.index(((((((c5//UOp.const(dtypes.weakint, 8))%UOp.const(dtypes.weakint, 8))*UOp.const(dtypes.weakint, 8))+(c5%UOp.const(dtypes.weakint, 8)))+(((c2*UOp.const(dtypes.weakint, 40))+(c5//UOp.const(dtypes.weakint, 64)))*UOp.const(dtypes.weakint, 64)))+(c1*UOp.const(dtypes.weakint, 81920))))
 c7 = UOp(Ops.PARAM, dtypes.float.ptr(64), arg=2, src=())
 c8 = c7.index(c3)
 c9 = ((((c6+(c8*UOp.const(dtypes.float, -1.0)))*(c6+(c8*UOp.const(dtypes.float, -1.0)))).reduce(c5, arg=Ops.ADD)*UOp.const(dtypes.float, 0.000390625))+UOp.const(dtypes.float, 1e-05)).sqrt().reciprocal()

@@ -15,11 +15,11 @@ def get_gated_load_uop(valid:UOp, idx:UOp):

 def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
 return UOp(Ops.LOAD, dtypes.float.vec(4), (
-UOp(Ops.PARAM, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.VECTORIZE, dtypes.index.vec(2), idx).valid(valid), ptr=True),
+UOp(Ops.PARAM, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.VECTORIZE, dtypes.weakint.vec(2), idx).valid(valid), ptr=True),
 UOp(Ops.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
 ))

-def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, nmax),), expr)
+def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.weakint, (UOp.const(dtypes.weakint, nmax),), expr)
 def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
 def Range(n, nmax): return UOp.range(nmax, n)

@@ -478,7 +478,7 @@ class TestDropTrueGate(unittest.TestCase):
 from tinygrad.codegen.late.devectorizer import load_store_indexing
 from tinygrad.uop.ops import graph_rewrite
 buf = UOp(Ops.PARAM, dtypes.int.ptr(), arg=0)
-idx = UOp.const(dtypes.index, 0)
+idx = UOp.const(dtypes.weakint, 0)
 true_gate = UOp.const(dtypes.bool, True)
 index_with_gate = UOp(Ops.INDEX, dtypes.int.ptr(), (buf, idx, true_gate))
 # apply the optimization

@@ -495,7 +495,7 @@ class TestRangeShrink(unittest.TestCase):
 def test_range_shrink_single_guard(self):
 # range 0..203 guarded by r < 4 everywhere -> shrink to 0..3
 r = Range(0, 204)
-load = get_gated_load_uop(r < UOp.const(dtypes.index, 4), r)
+load = get_gated_load_uop(r < UOp.const(dtypes.weakint, 4), r)
 ranges = self.get_ranges(load.sink())
 self.assertEqual(len(ranges), 1)
 self.assertEqual(ranges[0].src[0].arg, 4)
@@ -503,8 +503,8 @@ class TestRangeShrink(unittest.TestCase):
 def test_range_shrink_picks_max_guard(self):
 # two loads guard the same range with r < 4 and r < 8 -> shrink to max(4, 8) = 8
 r = Range(0, 204)
-load1 = get_gated_load_uop(r < UOp.const(dtypes.index, 4), r)
-load2 = get_gated_load_uop(r < UOp.const(dtypes.index, 8), r)
+load1 = get_gated_load_uop(r < UOp.const(dtypes.weakint, 4), r)
+load2 = get_gated_load_uop(r < UOp.const(dtypes.weakint, 8), r)
 ranges = self.get_ranges(UOp.sink(load1, load2))
 self.assertEqual(len(ranges), 1)
 self.assertEqual(ranges[0].src[0].arg, 8)
@@ -512,7 +512,7 @@ class TestRangeShrink(unittest.TestCase):
 def test_range_no_shrink_guard_ge_max(self):
 # guard r < 300 with range max 204 -> no shrink (guard doesn't constrain)
 r = Range(0, 204)
-load = get_gated_load_uop(r < UOp.const(dtypes.index, 300), r)
+load = get_gated_load_uop(r < UOp.const(dtypes.weakint, 300), r)
 ranges = self.get_ranges(load.sink())
 self.assertEqual(len(ranges), 1)
 self.assertEqual(ranges[0].src[0].arg, 204)
@@ -520,7 +520,7 @@ class TestRangeShrink(unittest.TestCase):
 def test_range_no_shrink_when_unguarded_elsewhere(self):
 # one load guards r < 4, but another load uses r without a gate -> no shrink
 r = Range(0, 204)
-load1 = get_gated_load_uop(r < UOp.const(dtypes.index, 4), r)
+load1 = get_gated_load_uop(r < UOp.const(dtypes.weakint, 4), r)
 load2 = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.PARAM, dtypes.float.ptr(), arg=1).index(r, ptr=True),))
 ranges = self.get_ranges(UOp.sink(load1, load2))
 self.assertEqual(len(ranges), 1)
@@ -529,7 +529,7 @@ class TestRangeShrink(unittest.TestCase):
 def test_range_no_shrink_when_used_in_reduce(self):
 # range used in both a gated load AND directly in the reduce expression -> no shrink
 r = Range(0, 204)
-gated_load = get_gated_load_uop(r < UOp.const(dtypes.index, 4), r)
+gated_load = get_gated_load_uop(r < UOp.const(dtypes.weakint, 4), r)
 red = UOp(Ops.REDUCE, dtypes.float, (r.cast(dtypes.float) + gated_load, r), Ops.ADD)
 ranges = self.get_ranges(red.sink())
 self.assertEqual(len(ranges), 1)
@@ -538,7 +538,7 @@ class TestRangeShrink(unittest.TestCase):
 def test_range_shrink_to_single_iteration(self):
 # guard r < 1 shrinks range to 1 -> single iteration, range eliminated entirely
 r = Range(0, 204)
-load = get_gated_load_uop(r < UOp.const(dtypes.index, 1), r)
+load = get_gated_load_uop(r < UOp.const(dtypes.weakint, 1), r)
 ranges = self.get_ranges(load.sink())
 self.assertEqual(len(ranges), 0)
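The invariant behind all the shrink cases above, restated in plain Python: when every use of a range sits behind a guard r < c, iterating past c is dead work, so the loop bound can drop to min(c, old_max):

assert [r for r in range(204) if r < 4] == list(range(4))      # shrink to the guard
assert [r for r in range(204) if r < 300] == list(range(204))  # guard >= max: no shrink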
@@ -205,7 +205,7 @@ class TestUOpGraph(unittest.TestCase):

 def test_where_same_fold(self):
 v = UOp.variable('tmp', 0, 1)
-c0 = UOp.const(dtypes.index, 0)
+c0 = UOp.const(dtypes.weakint, 0)
 vc = UOp(Ops.CMPNE, dtypes.bool, (v, c0))
 c1 = UOp.const(dtypes.float, 1.0)
 out = UOp(Ops.WHERE, dtypes.float, (vc, c1, c1))
@@ -469,16 +469,16 @@ class TestUOpGraph(unittest.TestCase):
 # mnist indexing with split reduceop
 # Make sure we are not doing math on the loaded index, which would promote it to long
 c0 = UOp(Ops.PARAM, dtypes.uchar.ptr(128000), arg=0, src=())
-c1 = UOp.range(UOp.const(dtypes.index, 512), 1, AxisType.LOOP)
-c2 = UOp.range(UOp.const(dtypes.index, 250), 2, AxisType.LOOP)
+c1 = UOp.range(UOp.const(dtypes.weakint, 512), 1, AxisType.LOOP)
+c2 = UOp.range(UOp.const(dtypes.weakint, 250), 2, AxisType.LOOP)
 c3 = UOp(Ops.PARAM, dtypes.int.ptr(512), arg=1, src=())
 c4 = c3.index(c1)
-c5 = UOp.range(UOp.const(dtypes.index, 240), 0, AxisType.REDUCE)
-c6 = ((c2*UOp.const(dtypes.index, 240))+c5)
+c5 = UOp.range(UOp.const(dtypes.weakint, 240), 0, AxisType.REDUCE)
+c6 = ((c2*UOp.const(dtypes.weakint, 240))+c5)
 c7 = UOp(Ops.PARAM, dtypes.uchar.ptr(60000), arg=2, src=())
 c8 = c7.index(c6)
 c9 = ((c4<0).where((c4+60000), c4)!=c6.cast(dtypes.int)).where(0, c8.cast(dtypes.uint).cast(dtypes.uchar)).reduce(c5, arg=Ops.ADD)
-c10 = c0.index(((c1*UOp.const(dtypes.index, 250))+c2)).store(c9).end(c1, c2)
+c10 = c0.index(((c1*UOp.const(dtypes.weakint, 250))+c2)).store(c9).end(c1, c2)
 uops = to_uops_list([c10])
 for u in uops:
 self.assertNotEqual(u.dtype, dtypes.long)
@@ -486,19 +486,19 @@ class TestUOpGraph(unittest.TestCase):
 def test_load_idx_no_math_on_loaded(self):
 # test the (x+y)<c pattern where x has loads - we shouldn't do math on loaded indices
 c0 = UOp(Ops.PARAM, dtypes.uchar.ptr(128000), arg=0, src=())
-c1 = UOp.range(UOp.const(dtypes.index, 512), 1, AxisType.LOOP)
-c2 = UOp.range(UOp.const(dtypes.index, 250), 2, AxisType.LOOP)
+c1 = UOp.range(UOp.const(dtypes.weakint, 512), 1, AxisType.LOOP)
+c2 = UOp.range(UOp.const(dtypes.weakint, 250), 2, AxisType.LOOP)
 c3 = UOp(Ops.PARAM, dtypes.int.ptr(512), arg=1, src=())
 c4 = c3.index(c1) # c4 is a load
-c5 = UOp.range(UOp.const(dtypes.index, 240), 0, AxisType.REDUCE)
-c6 = ((c2*UOp.const(dtypes.index, 240))+c5)
+c5 = UOp.range(UOp.const(dtypes.weakint, 240), 0, AxisType.REDUCE)
+c6 = ((c2*UOp.const(dtypes.weakint, 240))+c5)
 c7 = UOp(Ops.PARAM, dtypes.uchar.ptr(60000), arg=2, src=())
 c8 = c7.index(c6)
 # (loaded + range) < const pattern - loaded value shouldn't be promoted to long
-loaded_idx = c4.cast(dtypes.index)
-comparison = (loaded_idx + c5) < UOp.const(dtypes.index, 60000)
+loaded_idx = c4.cast(dtypes.weakint)
+comparison = (loaded_idx + c5) < UOp.const(dtypes.weakint, 60000)
 c9 = comparison.where(c8.cast(dtypes.uint).cast(dtypes.uchar), 0).reduce(c5, arg=Ops.ADD)
-c10 = c0.index(((c1*UOp.const(dtypes.index, 250))+c2)).store(c9).end(c1, c2)
+c10 = c0.index(((c1*UOp.const(dtypes.weakint, 250))+c2)).store(c9).end(c1, c2)
 uops = to_uops_list([c10])
 for u in uops:
 self.assertNotEqual(u.dtype, dtypes.long)

@@ -12,13 +12,13 @@ from tinygrad.uop.validate import uops_to_z3
 def check_uop_against_string(self, v:UOp, s:str):
 sym_vars = {v.render():v for v in v.toposort() if v.op in (Ops.DEFINE_VAR, Ops.RANGE, Ops.SPECIAL)}
 s_eval = eval(s, sym_vars)
-if isinstance(s_eval, int) and v.dtype==dtypes.index: s_eval = UOp.const(dtypes.index, s_eval)
+if isinstance(s_eval, int) and v.dtype==dtypes.weakint: s_eval = UOp.const(dtypes.weakint, s_eval)
 elif isinstance(s_eval, (bool, int, float)): s_eval = UOp.const(dtypes.from_py(s_eval), s_eval)
 s_eval = graph_rewrite(s_eval, commutative, name="cannonicalize eval")
 self.assertIs(s_eval, v, f"eval did not match simplified: {s_eval} != {v.render()} for {s}")

-def Variable(name: str, min_val: ConstType, max_val: ConstType, dtype: DType=dtypes.index): return UOp.variable(name,min_val,max_val,dtype)
-def uconst(val): return UOp.const(dtypes.index, val)
+def Variable(name: str, min_val: ConstType, max_val: ConstType, dtype: DType=dtypes.weakint): return UOp.variable(name,min_val,max_val,dtype)
+def uconst(val): return UOp.const(dtypes.weakint, val)
 def usum(ops): return functools.reduce(lambda x,y: x+y, ops)
 def uand(ops): return functools.reduce(lambda x,y: x*y, ops)

@@ -245,12 +245,12 @@ class TestSymbolic(unittest.TestCase):
 self.assertEqual((Variable("x", -10, 0)%Variable("y", 1, 10))._min_max, (-9, 0))

 def test_range_div_its_symbolic_bound(self):
-a = Variable("a", 1, 10, dtypes.index)
+a = Variable("a", 1, 10, dtypes.weakint)
 ridx0 = UOp.range(a+2, 0)
 self.helper_test_variable(ridx0//(a+2), 0, 0, "0")

 def test_range_mod_its_symbolic_bound(self):
-a = Variable("a", 1, 10, dtypes.index)
+a = Variable("a", 1, 10, dtypes.weakint)
 ridx = UOp.range(a+2, 0)
 self.helper_test_variable(ridx%(a+2), 0, 11, "r0")

@@ -941,7 +941,7 @@ class TestSymbolic(unittest.TestCase):
 self.helper_test_variable((numerator//denominator)<=0, 1, 1, "True")

 def test_symbolic_range_doesnt_collapse(self):
-r0 = UOp.range((Variable("a", 1, 10)<5).cast(dtypes.index), 0)
+r0 = UOp.range((Variable("a", 1, 10)<5).cast(dtypes.weakint), 0)
 self.helper_test_variable(r0, 0, 0, "r0")

 def test_const_reciprocal(self):

@@ -1202,16 +1202,16 @@ class TestInvalidIndex(unittest.TestCase):
 self.assertIs((UOp.invalid()<Variable("a",0,10)).simplify().dtype, dtypes.bool)

 def test_alu_invalid_vconst(self):
-c1 = UOp.const(dtypes.index.vec(4), (1, 1, Invalid, Invalid))
-c2 = UOp.const(dtypes.index.vec(4), (1, Invalid, 1, 1))
-self.assertIs((c1+c2).simplify(), UOp.const(dtypes.index.vec(4), (2, Invalid, Invalid, Invalid)))
+c1 = UOp.const(dtypes.weakint.vec(4), (1, 1, Invalid, Invalid))
+c2 = UOp.const(dtypes.weakint.vec(4), (1, Invalid, 1, 1))
+self.assertIs((c1+c2).simplify(), UOp.const(dtypes.weakint.vec(4), (2, Invalid, Invalid, Invalid)))

 class TestStoreLoadFolding(unittest.TestCase):
 """Tests for store(index, load(index)) -> NOOP rule. This rule matches patterns that EMERGE during simplification."""
 def test_store_load_folding(self):
 # store(idx, load(idx)) -> NOOP, including emergent patterns like store(idx, load(idx) + 0)
 buf = UOp(Ops.PARAM, dtypes.int.ptr(), arg=0)
-index = buf.index(UOp.const(dtypes.index, 0))
+index = buf.index(UOp.const(dtypes.weakint, 0))
 # Direct: store(idx, load(idx)) -> NOOP
 self.assertEqual(graph_rewrite(index.store(index.load()), sym).op, Ops.NOOP)
 # Emergent: store(idx, load(idx) + 0) -> store(idx, load(idx)) -> NOOP

@@ -1250,10 +1250,10 @@ class TestGatedUopGivenValid(unittest.TestCase):

 idx0 = (r0 + uconst(-1)) // uconst(3)
 idx1 = r0 % uconst(3)
-idx:UOp = (r0 < 3).where(UOp(Ops.VECTORIZE, dtypes.index.vec(2), (idx0, idx1)), UOp.invalid())
+idx:UOp = (r0 < 3).where(UOp(Ops.VECTORIZE, dtypes.weakint.vec(2), (idx0, idx1)), UOp.invalid())
 idx = graph_rewrite(idx, pm_simplify_valid)
 # NOTE: independent simplification: (r0-1)//3 -> 0, r0%3 -> r0 when r0 in [0,2]
-expected_vec = UOp(Ops.VECTORIZE, dtypes.index.vec(2), (uconst(0), r0))
+expected_vec = UOp(Ops.VECTORIZE, dtypes.weakint.vec(2), (uconst(0), r0))
 self.assertEqual(idx, (r0 < 3).where(expected_vec, UOp.invalid()))

 class TestRangeSplitting(unittest.TestCase):

@@ -144,7 +144,7 @@ class TestVminVmaxProperties(unittest.TestCase):
 self.assertNotEqual(i.vmin, i.vmax)

 def test_vmin_vmax_invalid_vconst(self):
-x = UOp.const(dtypes.index.vec(4), (0, 4, Invalid, Invalid))
+x = UOp.const(dtypes.weakint.vec(4), (0, 4, Invalid, Invalid))
 self.assertLess(x.vmin, 0)
 self.assertGreater(x.vmax, 4)

@@ -126,7 +126,7 @@ class TestValidateOOB(unittest.TestCase):
 buf0 = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
 buf1 = UOp(Ops.PARAM, dtypes.int.ptr(64), (), 1)
 r = UOp.range(42, 0, AxisType.GLOBAL)
-ld0 = buf0.index(r.valid(r < 8), ptr=True).load(dtype=dtypes.int).cast(dtypes.index)
+ld0 = buf0.index(r.valid(r < 8), ptr=True).load(dtype=dtypes.int).cast(dtypes.weakint)
 to_uops_list([buf1.index((ld0 * 2).valid((ld0 >= 0) & (ld0 < 32)), ptr=True).load(dtype=dtypes.int)]) # valid
 with self.assertRaises(RuntimeError):
 to_uops_list([buf1.index((ld0 * 2).valid((ld0 >= 0) & (ld0 < 64)), ptr=True).load(dtype=dtypes.int)]) # oob

@@ -135,7 +135,7 @@ class TestValidateOOB(unittest.TestCase):
 with Context(CHECK_OOB=1, SPEC=2):
 buf_bool = UOp(Ops.PARAM, dtypes.bool.ptr(16), (), 0)
 buf_int = UOp(Ops.PARAM, dtypes.int.ptr(8), (), 1)
-gidx = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 16),), "gidx0")
+gidx = UOp(Ops.SPECIAL, dtypes.weakint, (UOp.const(dtypes.weakint, 16),), "gidx0")
 ld_bool = buf_bool.index(gidx, ptr=True).load()
 with self.assertRaises(RuntimeError):
 to_uops_list([buf_int.index(gidx.valid(ld_bool), ptr=True).load()]) # gidx 0..15, buf_int size 8

@@ -35,7 +35,7 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
 if len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
 # try to split up dims: (a,) -> (b, c)
 if limited == dims: limited = _split_dims(dims, max_sizes)
-raw_idxs = [UOp(Ops.SPECIAL, dtypes.index, (sint_to_uop(s),), (f"{prefix}{i}")) for i,s in enumerate(limited)]
+raw_idxs = [UOp(Ops.SPECIAL, dtypes.weakint, (sint_to_uop(s),), (f"{prefix}{i}")) for i,s in enumerate(limited)]
 if len(limited) < len(dims):
 ret = []
 if (contraction:=get_contraction(dims, limited)) is None: raise RuntimeError(f"get_contraction should not be None {dims=} {limited=}")

@@ -75,7 +75,7 @@ def add_gpudims(ctx:Renderer, s:UOp):

 # get the idxs
 ki: KernelInfo = s.arg
-if ctx.has_threads: idxs = [UOp.variable("core_id", 0, int(global_shape[0])-1, dtypes.int).cast(dtypes.index)]
+if ctx.has_threads: idxs = [UOp.variable("core_id", 0, int(global_shape[0])-1, dtypes.int).cast(dtypes.weakint)]
 elif ki.dont_use_locals:
 assert not local_dims, "can't use locals if there's no local dims"
 idxs = get_grouped_dims("idx", global_shape, ctx.global_max, reverse=True)

@@ -197,7 +197,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
 return UOp(Ops.VCAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret)

 def get_image_idx(idx:UOp, width:int):
-oidx = UOp(Ops.VECTORIZE, dtypes.index.vec(2), (((x:=idx.src[1].get_idx()) // 4) % width, (x // (4*width))))
+oidx = UOp(Ops.VECTORIZE, dtypes.weakint.vec(2), (((x:=idx.src[1].get_idx()) // 4) % width, (x // (4*width))))
 return idx.replace(src=(idx.src[0], oidx.valid(idx.src[1].get_valid())))

 def image_fixup(ls:UOp):

@@ -208,7 +208,7 @@ def image_fixup(ls:UOp):
 return ls.replace(src=(idx,)+ls.src[1:])

 # this is an unprocessed image without a cast, aka unfoldable image load. this doesn't work for stores
-if isinstance(image_dtype:=ls.src[0].dtype, ImageDType) and ls.src[0].src[1].get_idx().dtype != dtypes.index.vec(2):
+if isinstance(image_dtype:=ls.src[0].dtype, ImageDType) and ls.src[0].src[1].get_idx().dtype != dtypes.weakint.vec(2):
 assert ls.op is Ops.LOAD, "if an image store isn't upcasted to 4, we can't store it"
 x, idx = ls.src[0].src[1].get_idx(), get_image_idx(ls.src[0], image_dtype.shape[1])
 vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:])

@@ -263,7 +263,7 @@ def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp, bcast:UOp|None=None):
 # simple scalar index: one lane, all components
 pairs = [(0, c) for c in range(cnt)]
 idx_lanes, offsets = (tuple(x) for x in zip(*pairs))
-return buf.broadcast(len(pairs)).index(idx.gep(idx_lanes)*cnt + UOp.const(dtypes.index.vec(len(pairs)), offsets), ptr=True)
+return buf.broadcast(len(pairs)).index(idx.gep(idx_lanes)*cnt + UOp.const(dtypes.weakint.vec(len(pairs)), offsets), ptr=True)

 devectorize_buf_and_index = PatternMatcher([
 (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),

@@ -157,5 +157,5 @@ def no_load(u:UOp) -> bool: return not any(x.op is Ops.INDEX for x in u.backward
 pm_load_collapse = PatternMatcher([
 (UPat(Ops.REDUCE, arg=Ops.ADD, src=(UPat.var("u"), UPat()), name="red"), reduce_load_collapse),
 # we want to make sure we dont do math on a loaded index since that can cause overflow, this undoes the rule in pm_reduce_load_collapse
-((UPat.var("x", dtypes.index)+UPat.var("y"))<UPat.var("c"), lambda x,y,c: x < c-y if no_load(y) and no_load(c) and not no_load(x) else None),
+((UPat.var("x", dtypes.weakint)+UPat.var("y"))<UPat.var("c"), lambda x,y,c: x < c-y if no_load(y) and no_load(c) and not no_load(x) else None),
 ])

@@ -331,7 +331,7 @@ class Compiled:
 # TODO: move this to each Device
 # this only tracks if the dtype is natively supported, it may be supported in the frontend using decomps
 def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
-if dtype == dtypes.index: return False
+if dtype == dtypes.weakint: return False
 if device is None: device = Device.DEFAULT
 if dtype == dtypes.bfloat16:
 if device == "METAL": return not CI
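Two consequences of the rename, per the device hunk above and the dtypes.py hunks that follow (import paths here are assumptions, not part of the diff):

from tinygrad.dtype import dtypes
from tinygrad.device import is_dtype_supported  # assumed import path

assert dtypes.is_int(dtypes.weakint)           # weakint participates in the int predicates
assert not is_dtype_supported(dtypes.weakint)  # but no backend stores it natively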
@@ -159,7 +159,7 @@ class dtypes:
 def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats or isinstance(x, ImageDType)
 @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
 @functools.cache
-def is_int(x: DType) -> bool: return x.scalar() in (dtypes.ints + (dtypes.index,))
+def is_int(x: DType) -> bool: return x.scalar() in (dtypes.ints + (dtypes.weakint,))
 @staticmethod
 @functools.cache
 def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints

@@ -180,7 +180,7 @@ class dtypes:
 return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52),
 dtypes.fp8e4m3: (4, 3), dtypes.fp8e5m2: (5, 2), dtypes.fp8e4m3fnuz: (4, 3), dtypes.fp8e5m2fnuz: (5, 2)}[dtype]
 void: Final[DType] = DType.new(-1, 0, "void", None)
-index: Final[DType] = DType.new(-1, 800, "index", None)
+weakint: Final[DType] = DType.new(-1, 800, "weakint", None)
 bool: Final[DType] = DType.new(0, 1, "bool", '?')
 int8: Final[DType] = DType.new(1, 8, "signed char", 'b')
 uint8: Final[DType] = DType.new(2, 8, "unsigned char", 'B')

@@ -227,7 +227,7 @@ class dtypes:
 uints = (uint8, uint16, uint32, uint64)
 sints = (int8, int16, int32, int64)
 ints = uints + sints
-all = floats + ints + (bool, index) # noqa: A003
+all = floats + ints + (bool, weakint) # noqa: A003

 if (env_default_float := getenv("DEFAULT_FLOAT", "")):
 dtypes.default_float = getattr(dtypes, env_default_float.lower())

@@ -237,7 +237,8 @@ DTypeLike = str|DType
 def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype.lower())

 # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
-# we don't support weak type and complex type
+# we don't support complex type
+# TODO: weakint and weakfloat in lattice
 promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],
 dtypes.int64: [dtypes.uint64], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
 dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.fp8e4m3, dtypes.fp8e5m2, dtypes.fp8e4m3fnuz, dtypes.fp8e5m2fnuz],

@@ -254,8 +255,8 @@ def least_upper_dtype(*ds:DType) -> DType:
 if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
 def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.default_float)

-DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void", "index", "_"))}
-INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void", "index":"index"}
+DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void", "weakint", "_"))}
+INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void", "weakint":"weakint"}

 @functools.cache
 def can_lossless_cast(dt0:DType, dt1:DType) -> bool:

@@ -263,7 +264,7 @@ def can_lossless_cast(dt0:DType, dt1:DType) -> bool:
 # similar to https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html
 if dt0 == dt1 or dt0 == dtypes.bool: return True
 match dt1:
-case dtypes.index: return dt0 in dtypes.ints
+case dtypes.weakint: return dt0 in dtypes.ints
 case dtypes.double: return dt0 in (dtypes.float, dtypes.half, dtypes.bfloat16, *dtypes.fp8s,
 dtypes.uint32, dtypes.uint16, dtypes.uint8, dtypes.int32, dtypes.int16, dtypes.int8)
 case dtypes.float: return dt0 in (dtypes.half, dtypes.bfloat16, *dtypes.fp8s, dtypes.uint16, dtypes.uint8, dtypes.int16, dtypes.int8)
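Spot-checking the new match arm, assuming can_lossless_cast is importable from tinygrad.dtype (an assumption; the diff only shows its body):

from tinygrad.dtype import dtypes, can_lossless_cast  # assumed import path

assert can_lossless_cast(dtypes.int64, dtypes.weakint)        # any int is lossless into weakint
assert not can_lossless_cast(dtypes.float32, dtypes.weakint)  # floats are not in dtypes.ints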
@@ -347,12 +347,12 @@ def _embedding_bwd(grad_emb:UOp, call:UOp) -> tuple:
 # each device owns [offset, offset+local_vocab_size) of the global vocabulary
 dnum = UOp.variable("_device_num", 0, ndev-1)
 offset = dnum * local_vocab_size
-global_token_id = idx_flat[i].cast(dtypes.index)
+global_token_id = idx_flat[i].cast(dtypes.weakint)
 local_token_id = (global_token_id - offset).clip(0, grad_weight.shape[0]-1)
 in_range = (global_token_id >= offset) & (global_token_id < (offset + local_vocab_size))
 grad_val = in_range.where(grad_emb_flat[i, j].cast(dtypes.float), 0.0)
 else:
-local_token_id = idx_flat[i].clip(0, grad_weight.shape[0]-1).cast(dtypes.index)
+local_token_id = idx_flat[i].clip(0, grad_weight.shape[0]-1).cast(dtypes.weakint)
 grad_val = grad_emb_flat[i, j].cast(dtypes.float)
 # atomic scatter-add: grad_weight[token_id, j] += grad_emb_flat[i, j]
 if device in ("CPU", "NULL"): atomic_arg = "__atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED);"

@@ -52,8 +52,8 @@ class IndexingContext:
 range_idx: Iterator[int] = field(default_factory=itertools.count)
 def new_range(self, s:sint, axistype:AxisType=AxisType.LOOP) -> UOp:
 if isinstance(s, UOp) and s.op is Ops.RANGE: return s
-# if a range has a 1 src, it's the same as UOp.const(dtypes.index, 0)
-return UOp.range(s, next(self.range_idx), axistype) if resolve(s!=1) else UOp.const(dtypes.index, 0)
+# if a range has a 1 src, it's the same as UOp.const(dtypes.weakint, 0)
+return UOp.range(s, next(self.range_idx), axistype) if resolve(s!=1) else UOp.const(dtypes.weakint, 0)

 def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
 if x.op in {Ops.BUFFERIZE, Ops.INDEX}: return None

@@ -118,7 +118,7 @@ def _apply_reshape(in_shape:tuple[sint,...], out_shape:tuple[sint, ...], urngs:U
 for s,src in list(zip(out_shape, urngs.src))[::-1]:
 axes_in.append(acc*src)
 acc *= s
-combined_axes = UOp.const(dtypes.index, 0).sum(*axes_in)
+combined_axes = UOp.const(dtypes.weakint, 0).sum(*axes_in)
 axes_out:list[UOp] = []
 for s in in_shape[::-1]:
 axes_out.append(combined_axes % s)

@@ -172,7 +172,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
 # treat MSTACK/MSELECT like SINK
 if x.op in {Ops.MSTACK, Ops.MSELECT}: continue

-if x.dtype.scalar() == dtypes.index: continue # TODO: why do I need this?
+if x.dtype.scalar() == dtypes.weakint: continue # TODO: why do I need this?
 ending_ranges[x] = sum([ending_ranges.get(u, []) for u in consumer_map[x]], [])

 # *** the ranges on the output are

@@ -94,7 +94,7 @@ def split_reduceop(reduce:UOp, x:UOp):
 # split is moved to the end to provide maximum locality for the second phase reduce.

 # get expanded by rangeifying the UOp x
-indexed = x.index(*[UOp.range(s, i) if resolve(s>1) else UOp.const(dtypes.index, 0) for i,s in enumerate(x.shape)])
+indexed = x.index(*[UOp.range(s, i) if resolve(s>1) else UOp.const(dtypes.weakint, 0) for i,s in enumerate(x.shape)])
 range_nums = [y.arg[0] for y in indexed.substitute({x.base:UOp(Ops.NOOP)}, extra_pm=pm_mops).ranges]
 is_expanded = [i not in range_nums for i in range(len(x.shape))]

@@ -132,9 +132,10 @@ class Tensor(OpMixin):

 # create a UOp from the different types of inputs
 if isinstance(data, UOp):
-assert _dtype is None or _dtype==data.dtype or data.dtype==dtypes.index, f"dtype mismatch: {_dtype} vs {data.dtype}"
-# if data is dtype.index that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of
-if data.dtype==dtypes.index: data = _index_to_concrete_int(data)
+assert _dtype is None or _dtype==data.dtype or data.dtype==dtypes.weakint, f"dtype mismatch: {_dtype} vs {data.dtype}"
+# if data is dtype.weakint that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of
+# TODO: remove this and stay in weakint
+if data.dtype==dtypes.weakint: data = _index_to_concrete_int(data)
 if data.op is Ops.BIND:
 var, val = data.unbind()
 # give the bound constant a device

@@ -107,17 +107,17 @@ div_and_mod_symbolic = PatternMatcher([
 # ** 1. Fast Inline Rules **
 ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)
 if c.vmin>0 and d.vmin>0 and x.vmin>=0 and a.vmin>=0 else None), # (x//c+a)//d -> (x+a*c)//(c*d)
-(UPat.var("x", dtypes.index) // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax < 0 else None),
-(UPat.var("x", dtypes.index) // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax <= 0 else None),
-((UPat.var("x", dtypes.index)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
+(UPat.var("x", dtypes.weakint) // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax < 0 else None),
+(UPat.var("x", dtypes.weakint) // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax <= 0 else None),
+((UPat.var("x", dtypes.weakint)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
 lambda x,c,n,d: ((x+c.arg%d.arg)//d + c.arg//d.arg) if c.arg%d.arg!=c.arg and x.vmin>=0 and n.vmin>=0 and d.arg>0 else None),
-((UPat.var("x", dtypes.index)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
+((UPat.var("x", dtypes.weakint)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
 lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None),

 # ** 2. Slow Rules **
-(UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d"), lambda d: fold_divmod_general(d, bool(CORRECT_DIVMOD_FOLDING))),
+(UPat((Ops.IDIV, Ops.MOD), dtypes.weakint, name="d"), lambda d: fold_divmod_general(d, bool(CORRECT_DIVMOD_FOLDING))),

 # NOTE: these have to go at the bottom or TestSymbolicOps.test_var loops
-(UPat.var("x", dtypes.index) % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None),
-(UPat.var("x", dtypes.index) % UPat.var("d"), lambda x,d: (x%(-d)) if d.vmax < 0 else None),
+(UPat.var("x", dtypes.weakint) % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None),
+(UPat.var("x", dtypes.weakint) % UPat.var("d"), lambda x,d: (x%(-d)) if d.vmax < 0 else None),
 ])
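The first fast inline rule above, (x//c + a)//d -> (x + a*c)//(c*d), is valid exactly under its guards (x >= 0, a >= 0, c > 0, d > 0); a brute-force check in plain Python:

for x in range(0, 120):
  for c in (1, 2, 3, 5):
    for a in (0, 1, 4):
      for d in (1, 2, 7):
        assert (x//c + a)//d == (x + a*c)//(c*d)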
@@ -59,9 +59,9 @@ def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
 return ret

 def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
-if len(arg) == 0: return UOp(Ops.VECTORIZE, dtypes.index.vec(0))
-elif all_int(arg): return UOp.const(dtypes.index.vec(len(arg)), arg)
-else: return UOp(Ops.VECTORIZE, dtypes.index.vec(len(arg)), tuple(UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in arg))
+if len(arg) == 0: return UOp(Ops.VECTORIZE, dtypes.weakint.vec(0))
+elif all_int(arg): return UOp.const(dtypes.weakint.vec(len(arg)), arg)
+else: return UOp(Ops.VECTORIZE, dtypes.weakint.vec(len(arg)), tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))

 def consumer_map_from_toposort(lst:Iterable[UOp]):
 ret: dict[UOp, dict[UOp, None]] = {}

@@ -426,9 +426,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
 assert len(idx) == len(self.shape), f"__getitem__ shape mismatch, indexing {self.shape} with {len(idx)} args"
 if len(slice_idx:=[i for i,x in enumerate(idx) if isinstance(x, slice)]):
 perm = self.permute(tuple([i for i in range(self.ndim) if i not in slice_idx] + slice_idx))
-return perm.index(*[UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in idx if not isinstance(x, slice)], ptr=True)
+return perm.index(*[UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in idx if not isinstance(x, slice)], ptr=True)
 else:
-return self.index(*[UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in idx])
+return self.index(*[UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in idx])
 def const_like(self, b:ConstLike):
 # constants can optionally have a DEVICE source
 return UOp.const(self.dtype, b, device=self._device, shape=self._shape)

@@ -482,21 +482,21 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
 ret = ret.replace(src=(UOp.unique(None if unique is True else unique),) + ret.src)
 return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None else ret
 @staticmethod
-def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.index, src=(), **kwargs):
+def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.weakint, src=(), **kwargs):
 return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs)
 @staticmethod
-def special(end:sint, name:str, dtype=dtypes.index): return UOp(Ops.SPECIAL, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=name)
+def special(end:sint, name:str, dtype=dtypes.weakint): return UOp(Ops.SPECIAL, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=name)
 def r(self, op:Ops, axis:tuple[int, ...]):
 axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
 return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis)) if len(axis) else self
 @staticmethod
-def invalid(count=1): return UOp(Ops.CONST, dtypes.index.vec(count), src=(), arg=Invalid)
+def invalid(count=1): return UOp(Ops.CONST, dtypes.weakint.vec(count), src=(), arg=Invalid)
 def valid(self, cond): return self if cond.op is Ops.WHERE and cond.arg else cond.where(self, UOp.invalid(self.dtype.count))
 def get_idx(self) -> UOp:
-assert self.dtype.scalar() is dtypes.index, "Can only call get_idx on index dtype"
+assert self.dtype.scalar() is dtypes.weakint, "Can only call get_idx on index dtype"
 return self.src[1] if self.op is Ops.WHERE and self.src[2].arg is Invalid else self
 def get_valid(self) -> UOp:
-assert self.dtype.scalar() is dtypes.index, "Can only call get_valid on index dtype"
+assert self.dtype.scalar() is dtypes.weakint, "Can only call get_valid on index dtype"
 return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid)
 def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
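A hedged sketch of the gating helpers touched above (assumed post-rename API; the `is` checks lean on tinygrad's hash-consing of identical UOps):

from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp

r = UOp.range(16, 0)                               # weakint range 0..15
gated = r.valid(r < UOp.const(dtypes.weakint, 8))  # builds WHERE(r < 8, r, Invalid)
assert gated.get_idx() is r                        # recovers the index expression
assert gated.get_valid() is (r < UOp.const(dtypes.weakint, 8))  # recovers the gate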
@@ -757,7 +757,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
 # *** uop Variable stuff ***

 @staticmethod
-def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.index) -> UOp:
+def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.weakint) -> UOp:
 assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}"
 return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
 @property

@@ -866,7 +866,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
 if self.op is Ops.VCONST and Invalid not in self.arg: return (min(self.arg), max(self.arg))
 if self.op is Ops.GEP: return self.src[0]._min_max
 # TODO: CAST to bool/unsigned is not monotone, still some case can be simplified
-if self.op is Ops.CAST and self.dtype in dtypes.floats+dtypes.sints+(dtypes.index,):
+if self.op is Ops.CAST and self.dtype in dtypes.floats+dtypes.sints+(dtypes.weakint,):
 return max(self.dtype.min, self.src[0].vmin), min(self.src[0].vmax, self.dtype.max)
 return self.dtype.min, self.dtype.max

@@ -991,7 +991,7 @@ python_alu: dict[Ops, Callable] = {
 def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
 if dtype.count > 1:
 return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
-if dtype==dtypes.index and op in GroupOp.Binary and Invalid in operands: return Invalid
+if dtype==dtypes.weakint and op in GroupOp.Binary and Invalid in operands: return Invalid
 alu = python_alu[op](*operands)
 return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu

@@ -1424,23 +1424,26 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=N
 rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx, enter_calls)
 return rewrite_ctx.walk_rewrite(sink) if walk else rewrite_ctx.unified_rewrite(sink)

-def sint_to_uop(x:sint, dtype=dtypes.index) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x.cast(dtype)
+def sint_to_uop(x:sint, dtype=dtypes.weakint) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x.cast(dtype)

 def select_dtype(u): return (dtypes.long if u.overflows(dtypes.int32) else dtypes.int).vec(u.dtype.count)
 pm_lower_index_dtype = PatternMatcher([
 # There are no Unary ops at this point in symbolic, those are introduced later
-(UPat(GroupOp.Binary, name="u", src=(UPat.var("x").cast(dtypes.index), UPat.var("y").cast(dtypes.index))), lambda u,x,y:
+(UPat(GroupOp.Binary, name="u", src=(UPat.var("x").cast(dtypes.weakint), UPat.var("y").cast(dtypes.weakint))), lambda u,x,y:
 x.cast(dt:=least_upper_dtype(select_dtype(u), x.dtype, y.dtype)).alu(u.op, y.cast(dt)).cast(u.dtype)),
-(UPat((Ops.CONST, Ops.VCONST), dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=select_dtype(u)).cast(u.dtype) if u.arg!=Invalid else None),
-(UPat(Ops.WHERE, dtypes.index, src=(UPat.var("cond"), UPat.var("x").cast(dtypes.index), UPat.var("y").cast(dtypes.index))), lambda cond,x,y:
-cond.where(x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)), y.cast(dt)).cast(dtypes.index)),
-(UPat(Ops.RANGE, src=(UPat.var("end").cast(dtypes.index)), name="r"), lambda r,end: r.replace(dtype=end.dtype, src=(end,)).cast(dtypes.index)),
-(UPat(Ops.VECTORIZE, src=UPat().cast(dtypes.index), name="v"),
-lambda v: v.replace(dtype=(dt:=select_dtype(v)), src=tuple(s.src[0].cast(dt.scalar()) for s in v.src)).cast(dtypes.index)),
+(UPat((Ops.CONST, Ops.VCONST), dtype=dtypes.weakint, name="u"),
+lambda u: u.replace(dtype=select_dtype(u)).cast(u.dtype) if u.arg!=Invalid else None),
+(UPat(Ops.WHERE, dtypes.weakint, src=(UPat.var("cond"), UPat.var("x").cast(dtypes.weakint), UPat.var("y").cast(dtypes.weakint))), lambda cond,x,y:
+cond.where(x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)), y.cast(dt)).cast(dtypes.weakint)),
+(UPat(Ops.RANGE, src=(UPat.var("end").cast(dtypes.weakint)), name="r"), lambda r,end: r.replace(dtype=end.dtype, src=(end,)).cast(dtypes.weakint)),
+(UPat(Ops.VECTORIZE, src=UPat().cast(dtypes.weakint), name="v"),
+lambda v: v.replace(dtype=(dt:=select_dtype(v)), src=tuple(s.src[0].cast(dt.scalar()) for s in v.src)).cast(dtypes.weakint)),
 # special can only be int32
-(UPat(Ops.SPECIAL, src=(UPat.var("var").cast(dtypes.index),), name="u"), lambda u,var: u.replace(dtype=dtypes.int, src=(var,)).cast(dtypes.index)),
-(UPat(Ops.DEFINE_VAR, dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=dtypes.int).cast(dtypes.index)),
-(UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.index), UPat.cvar("val").cast(dtypes.index))), lambda var,val: var.bind(val).cast(dtypes.index)),
+(UPat(Ops.SPECIAL, src=(UPat.var("var").cast(dtypes.weakint),), name="u"),
+lambda u,var: u.replace(dtype=dtypes.int, src=(var,)).cast(dtypes.weakint)),
+(UPat(Ops.DEFINE_VAR, dtype=dtypes.weakint, name="u"), lambda u: u.replace(dtype=dtypes.int).cast(dtypes.weakint)),
+(UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.weakint), UPat.cvar("val").cast(dtypes.weakint))),
+lambda var,val: var.bind(val).cast(dtypes.weakint)),
 # lower Invalid
 (UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond, ptr=True)),
 # remove hanging casts
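Per sint_to_uop and select_dtype above: Python ints enter the graph as weakint, and only at lowering time does select_dtype commit to a machine dtype (int32, or int64 when the value range overflows int32). A small sketch (import paths are assumptions):

from tinygrad.dtype import dtypes
from tinygrad.uop.ops import sint_to_uop  # assumed import path

assert sint_to_uop(42).dtype == dtypes.weakint  # no concrete width chosen yet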
@@ -1448,7 +1451,7 @@ pm_lower_index_dtype = PatternMatcher([
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))),
|
||||
lambda buf,idx,valid: buf.index(idx, valid, ptr=True)),
|
||||
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),
|
||||
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.index else s for s in n.src))),
|
||||
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))),
|
||||
])
|
||||
def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
|
||||
|
||||
@@ -1535,7 +1538,7 @@ pm_pyrender_extra = PatternMatcher([
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"),), name="x"), lambda x,d: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)})"),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
|
||||
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x:
|
||||
f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.index else ''})"),
|
||||
f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.weakint else ''})"),
|
||||
(UPat((Ops.CAST, Ops.BITCAST), name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({x.dtype})"),
|
||||
(UPat(Ops.SPECIAL, src=(UPat(Ops.CONST),), name="x"), lambda x: f"UOp.special({x.src[0].arg}, {repr(x.arg)}, dtype={x.dtype})"),
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"), lambda x,u,d:
|
||||
@@ -1546,7 +1549,7 @@ pm_pyrender_extra = PatternMatcher([
|
||||
# NOTE: range has srcs sometimes after control flow
|
||||
(UPat(Ops.RANGE, src=(UPat(Ops.CONST, name="c"),), allow_any_len=True, name="x"), lambda ctx,x,c:
|
||||
"UOp.range("+', '.join([str(c.arg)] + [repr(y) for y in x.arg])+
|
||||
(f', src={srcs(ctx, x.src[1:])}' if len(x.src) > 1 else '')+(', dtype='+str(x.dtype) if x.dtype is not dtypes.index else '')+")"),
|
||||
(f', src={srcs(ctx, x.src[1:])}' if len(x.src) > 1 else '')+(', dtype='+str(x.dtype) if x.dtype is not dtypes.weakint else '')+")"),
|
||||
# TODO: index shouldn't mismatch dtype
|
||||
(UPat(Ops.INDEX, src=(UPat(), UPat()), allow_any_len=True, name="x"), lambda ctx,x:
|
||||
f"{ctx[x.src[0]]}.index({ctx[x.src[1]]}, "+''.join([f"{ctx[xx]}, " for xx in x.src[2:]])+
|
||||
|
||||
@@ -54,7 +54,7 @@ shared_spec = PatternMatcher([
|
||||
(UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x:
|
||||
rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
|
||||
all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
|
||||
(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
|
||||
(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.weakint for y in x.src[1:]) or None),
|
||||
|
||||
# RANGE/SPECIAL define loops, END closes them
|
||||
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE))), lambda: True),
|
||||
@@ -69,13 +69,13 @@ shared_spec = PatternMatcher([
|
||||
# ***** UOp spec in the Tensor graph *****
|
||||
|
||||
movement_ops = PatternMatcher([
|
||||
(UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
|
||||
(UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index), UPat(dtype=dtypes.index))), lambda mv,x: True),
|
||||
(UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.weakint))), lambda mv,x: True),
|
||||
(UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.weakint), UPat(dtype=dtypes.weakint))), lambda mv,x: True),
|
||||
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)),
|
||||
|
||||
# inputs to movement ops
|
||||
(UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True),
|
||||
(UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True),
|
||||
(UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.weakint), lambda: True),
|
||||
(UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.weakint), lambda: True),
|
||||
|
||||
# AFTER on Movement Op, BUFFER, COPY, or BITCAST
|
||||
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS, Ops.BUFFER, Ops.BITCAST, Ops.COPY})),), allow_any_len=True),
|
||||
@@ -104,9 +104,9 @@ _tensor_spec = PatternMatcher([
|
||||
(UPat(Ops.MSTACK, name="x"), lambda x: all(isinstance(x.device, str) for x in x.src)),
|
||||
|
||||
# Tensor variable bindings
|
||||
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.index,))), arg=None), lambda: True),
|
||||
(UPat(Ops.BIND, (dtypes.int,dtypes.weakint,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.weakint,))), arg=None), lambda: True),
|
||||
# single-src BIND used for schedule cache key normalization
|
||||
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR),), arg=None), lambda: True),
|
||||
(UPat(Ops.BIND, (dtypes.int,dtypes.weakint,), (UPat(Ops.DEFINE_VAR),), arg=None), lambda: True),
|
||||
|
||||
# device or unique
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),)), lambda: True),
|
||||
@@ -189,7 +189,7 @@ shared_codegen_spec = PatternMatcher([
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf"), UPat.var("idx"), UPat.var("gate", dtype=dtypes.bool))), validate_index),
|
||||
|
||||
# SPECIAL
|
||||
(UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.index, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)),
|
||||
(UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.weakint, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)),
|
||||
|
||||
# BARRIER (on any length)
|
||||
(UPat(Ops.BARRIER, dtypes.void), lambda: True),
|
||||
@@ -199,7 +199,7 @@ shared_codegen_spec = PatternMatcher([
|
||||
|
||||
kernel_spec = PatternMatcher([
|
||||
# index is allowed here
|
||||
(UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True),
|
||||
(UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.weakint), lambda: True),
|
||||
|
||||
# UNROLL/CONTRACT is used here for WMMA
|
||||
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
|
||||
@@ -212,7 +212,7 @@ kernel_spec = PatternMatcher([
|
||||
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: True),
|
||||
|
||||
# reduce must be on ranges
|
||||
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype in (dtypes.index, dtypes.int) for y in x.src[1:])),
|
||||
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype in (dtypes.weakint, dtypes.int) for y in x.src[1:])),
|
||||
|
||||
# COPY/BUFFER_VIEW can have ranges appended
|
||||
(UPat(Ops.COPY, name="x", src=(UPat.var("s"), UPat(Ops.DEVICE)), allow_any_len=True, arg=None),
|
||||
@@ -236,7 +236,7 @@ program_spec = PatternMatcher([
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),

# make sure all index dtypes have been lowered (except CONST/RANGE/DEFINE_VAR which are valid index-typed)
(UPat(GroupOp.All-{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR, Ops.VCONST, Ops.VECTORIZE}, dtype=dtypes.index), lambda: False),
(UPat(GroupOp.All-{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR, Ops.VCONST, Ops.VECTORIZE}, dtype=dtypes.weakint), lambda: False),
(UPat(Ops.CONST, arg=Invalid), lambda: False),
(UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.arg) and len(x.arg)==x.dtype.vcount>1 and
type(x.arg) is type(x.dtype.const(x.arg))),
@@ -273,13 +273,13 @@ full_spec = PatternMatcher([
(UPat(Ops.CALL, dtype=dtypes.void), lambda: True),

# where on index in rhs position is fine
(UPat(Ops.WHERE, dtype=dtypes.index, src=(UPat(dtype=dtypes.bool), UPat(), UPat(dtype=dtypes.index))), lambda: True),
(UPat(Ops.WHERE, dtype=dtypes.weakint, src=(UPat(dtype=dtypes.bool), UPat(), UPat(dtype=dtypes.weakint))), lambda: True),
# allow index dtype on a restricted set of UOps
(UPat((Ops.ADD, Ops.MUL, Ops.MOD, Ops.IDIV, Ops.MAX,
Ops.SPECIAL, Ops.CAST, Ops.RANGE, Ops.VCONST, Ops.VECTORIZE), dtype=dtypes.index), lambda: True),
Ops.SPECIAL, Ops.CAST, Ops.RANGE, Ops.VCONST, Ops.VECTORIZE), dtype=dtypes.weakint), lambda: True),

# while BIND is being casted
(UPat(Ops.BIND, (dtypes.int, dtypes.index), (UPat(), UPat()), arg=None), lambda: True),
(UPat(Ops.BIND, (dtypes.int, dtypes.weakint), (UPat(), UPat()), arg=None), lambda: True),

# in progress MSTACK may lose device
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
@@ -51,15 +51,15 @@ def fold_add_divmod_recombine(x:UOp) -> UOp|None:
# this needs to be before symbolic so that 0*something_that_might_be_invalid doesnt become 0
propagate_invalid = PatternMatcher([
# propagate invalid, push it past children
(invalid_gate.cast(name="cast"), lambda i,x,cond,cast: x.cast(cast.dtype) if i.dtype is dtypes.index else None),
(invalid_gate.cast(name="cast"), lambda i,x,cond,cast: x.cast(cast.dtype) if i.dtype is dtypes.weakint else None),
(UPat(GroupOp.Unary, src=(invalid_gate,), name="alu"), lambda cond,x,alu,i: cond.where(x.alu(alu.op), i)),
(UPat(GroupOp.Binary-GroupOp.Comparison, src=(invalid_gate, UPat.var("y")), name="alu"), lambda cond,x,y,alu,i: cond.where(x.alu(alu.op,y), i)),
(UPat(GroupOp.Binary-GroupOp.Comparison, src=(UPat.var("y"), invalid_gate), name="alu"), lambda cond,x,y,alu,i: cond.where(y.alu(alu.op,x), i)),
# TODO: when can this happen? and is it always safe to just drop invalid?
(UPat(GroupOp.Comparison, src=(invalid_gate, UPat.var("y")), name="alu"), lambda cond,x,y,alu,i:
x.alu(alu.op,y) if i.dtype is dtypes.index else cond.where(x.alu(alu.op,y), i.cast(dtypes.bool))),
x.alu(alu.op,y) if i.dtype is dtypes.weakint else cond.where(x.alu(alu.op,y), i.cast(dtypes.bool))),
(UPat(GroupOp.Comparison, src=(UPat.var("y"), invalid_gate), name="alu"), lambda cond,x,y,alu,i:
y.alu(alu.op,x) if i.dtype is dtypes.index else cond.where(y.alu(alu.op,x), i.cast(dtypes.bool))),
y.alu(alu.op,x) if i.dtype is dtypes.weakint else cond.where(y.alu(alu.op,x), i.cast(dtypes.bool))),
# alu with invalid -> invalid
(UPat(GroupOp.Unary, src=(invalid_pat,)), lambda i: i),
(UPat(GroupOp.Binary-GroupOp.Comparison, src=[invalid_pat, UPat()]), lambda i: i),
@@ -77,26 +77,26 @@ symbolic_simple = propagate_invalid + PatternMatcher([
# ** self folding **
(UPat.var("x") + 0, lambda x: x), # x+0 -> x
(UPat.var("x") * 1, lambda x: x), # x*1 -> x
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) ^ 0, lambda x: x), # x^0 -> x
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) ^ 0, lambda x: x), # x^0 -> x
(UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
(UPat.var("x") // 1, lambda x: x), # x//1 -> x
(UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed)
# variations of (x%c)+(x//c)*c = x
(UPat(Ops.ADD, dtype=dtypes.index, name="x"), fold_add_divmod_recombine),
(UPat(Ops.ADD, dtype=dtypes.weakint, name="x"), fold_add_divmod_recombine),
(UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
(UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
(UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),
(UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
(UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
(UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, False), UPat.const(dtypes.bool, True)), lambda x: x.logical_not()),
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)).trunc(), lambda x: x),
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)).trunc(), lambda x: x),
# ** zero folding **
(UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
(UPat.var("x") % UPat.var("x"), lambda x: x.const_like(0)), # x%x -> 0
(UPat.var("x") ^ UPat.var("x"), lambda x: x.const_like(0)), # x^x -> 0
(UPat.var("x") & 0, lambda x: x.const_like(0)), # x&0 -> 0
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) != UPat.var("x"),
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) != UPat.var("x"),
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
# ** constant folding **
# TODO: add const folding for Ops.THREEFRY
@@ -192,7 +192,7 @@ gep_pushing = PatternMatcher([
# GEP in order is removed
(UPat(Ops.GEP, name="g"), lambda g: g.src[0] if not isinstance(g.dtype, PtrDType) and g.arg == tuple(range(g.src[0].dtype.count)) else None),
# push all GEPs through ALUs for index (TODO: remove this)
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu').f(Ops.GEP, dtype=dtypes.index, name='gep'),
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu').f(Ops.GEP, dtype=dtypes.weakint, name='gep'),
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
if not isinstance(gep.dtype, PtrDType) and not isinstance(alu.dtype, PtrDType) else None),
# CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)
@@ -207,7 +207,7 @@ gep_pushing = PatternMatcher([
commutative = PatternMatcher([
# ** COMMUTATIVE flipping (only for index) **
# NOTE: this can break merging vector math by only flipping some of them
(UPat(GroupOp.Commutative, dtype=dtypes.index, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
(UPat(GroupOp.Commutative, dtype=dtypes.weakint, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
])

symbolic = symbolic_simple+commutative+PatternMatcher([
@@ -224,7 +224,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
((UPat.var("y") + UPat.var("x")) + UPat.var("x"), lambda y,x: y+x*2),
((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3) if x2 is not x3 else None), # (x/x2)/x3 -> x/(x2*x3)
(-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
(UPat.cvar("y") * (UPat.var("x", dtype=dtypes.index) + UPat.cvar("c")), lambda x,y,c: (y*x)+(y*c)), # y*(x+c) -> y*x + y*c
(UPat.cvar("y") * (UPat.var("x", dtype=dtypes.weakint) + UPat.cvar("c")), lambda x,y,c: (y*x)+(y*c)), # y*(x+c) -> y*x + y*c
# ** where folding **
(UPat.var("cond", dtype=dtypes.bool).logical_not().where(UPat.var("t"), UPat.var("f")),
lambda cond, t, f: cond.where(f,t) if f.arg is not Invalid else None),
@@ -249,35 +249,35 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2)
# ** lt **
# c0*x<c1 for positive int c0,c1
((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.index))<UPat.cvar("c1", vec=False),
((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.weakint))<UPat.cvar("c1", vec=False),
lambda x,c0,c1: x<math.ceil(c1.arg/c0.arg) if c0.arg > 0 and c1.arg > 0 else None),
# c0*x<c1 for negative int c0 and non-positive c1
((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.index))<UPat.cvar("c1", vec=False),
((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.weakint))<UPat.cvar("c1", vec=False),
lambda x,c0,c1: (-x)<(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
# x//d<c
((UPat.var("x", dtype=dtypes.index)//UPat.cvar("d", vec=False))<UPat.cvar("c", vec=False),
((UPat.var("x", dtype=dtypes.weakint)//UPat.cvar("d", vec=False))<UPat.cvar("c", vec=False),
lambda x,d,c: (x<(c.arg*d.arg) if c.arg > 0 else x<(c.arg*d.arg-(d.arg-1))) if d.arg > 0 else None),
# ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
((UPat.var("x") + UPat.cvar("c1")) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
((UPat.var("x") * UPat.cvar("c1")) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
# *** rules from symbolic ***
# generic lt folding
(UPat.var("x", dtypes.index)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
(UPat.var("x", dtypes.index)*-1 < UPat.var("y")*-1, lambda x,y: y<x),
(UPat.var("x", dtypes.weakint)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
(UPat.var("x", dtypes.weakint)*-1 < UPat.var("y")*-1, lambda x,y: y<x),
# canonicalize a simplex with positive coefficients > 0. NOTE: not x < 1 means x > 0
((UPat.var("x", dtypes.index)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
((UPat.var("x", dtypes.weakint)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
# a range mod its own upper bound is just the range
(UPat(Ops.RANGE, src=UPat.var("end"), name="r")%UPat.var("end"), lambda r,end: r),
(UPat(Ops.RANGE, src=UPat.var("end"), name="r")//UPat.var("end"), lambda r,end: r.const_like(0)),
# cast/long folding
# if the intermediate cast doesnt narrow we can do it in one cast
(UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x.cast(b.dtype) if can_lossless_cast(x.dtype, a.dtype) else None),
(UPat.var('x', dtypes.ints+(dtypes.index,)).cast(dtypes.ints+(dtypes.index,), name="a").cast(name="b"),
(UPat.var('x', dtypes.ints+(dtypes.weakint,)).cast(dtypes.ints+(dtypes.weakint,), name="a").cast(name="b"),
lambda x,a,b: x.cast(b.dtype) if a.dtype.min<=x.vmin and x.vmax<=a.dtype.max else None),
# try to do math in int instead of long
(UPat(GroupOp.Binary, src=(UPat.var("x", dtypes.long), UPat.var("y", dtypes.long)), name="u"), lambda u,x,y:
x.cast(dtypes.int).alu(u.op, y.cast(dtypes.int)).cast(u.dtype) if not any(v.overflows(dtypes.int) for v in (u,x,y)) else None),
((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
((UPat.var("x", dtypes.weakint) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
# only RANGE/IF/STORE/KERNEL have side effects
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.BARRIER, Ops.END, Ops.UNROLL, Ops.LINEAR, Ops.BUFFERIZE}
@@ -432,7 +432,7 @@ sym = symbolic+pm_simplify_valid+PatternMatcher([
UPat.load(UPat(Ops.INDEX, name="index"))), allow_any_len=True, name="store"),
lambda index, gate, alt, store: UOp.store(index.src[0].index(gate.where(index.src[1], UOp.invalid())), alt, *store.src[2:])),
# fold gated LOAD/STORE
(UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.index, Invalid)).or_casted(),), allow_any_len=True, name="x"),
(UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.weakint, Invalid)).or_casted(),), allow_any_len=True, name="x"),
lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # invalid store does nothing. invalid load produces 0
(UPat(Ops.STORE, src=(UPat(), invalid_pat), allow_any_len=True), lambda i: UOp(Ops.NOOP)),
# store of where with invalid -> gated store
@@ -456,5 +456,5 @@ sym = symbolic+pm_simplify_valid+PatternMatcher([
# ** combine terms (opinionated) **
(-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
# (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
((UPat.var("x", dtypes.index) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
((UPat.var("x", dtypes.weakint) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
])
@@ -20,33 +20,33 @@ def create_bounded(name:str, vmin:int, vmax:int, z3ctx:z3.Context) -> tuple[z3.A
return (s:=z3.Int(name, ctx=z3ctx)), (vmin <= s)&(s <= vmax)

z3_renderer = PatternMatcher([
(UPat.var("cond").where(UPat.var("x"), UPat.const(dtypes.index, Invalid)), lambda x,cond,ctx: (ctx[1][x], ctx[1][cond])),
(UPat.var("cond").where(UPat.var("x"), UPat.const(dtypes.weakint, Invalid)), lambda x,cond,ctx: (ctx[1][x], ctx[1][cond])),
# variables
(UPat(Ops.SPECIAL, name="x"), lambda x,ctx: create_bounded(x.arg, 0, ctx[1][x.src[0]]-1, ctx[0])),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])),
(UPat(Ops.RANGE, name="x"), lambda x,ctx: create_bounded(x.render(simplify=False), 0, ctx[1][x.src[0]]-1, ctx[0])),
# loads are variables bounded by the min/max of the dtype. non-pointer INDEX is also a LOAD
(UPat((Ops.LOAD, Ops.INDEX), dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx:
(UPat((Ops.LOAD, Ops.INDEX), dtypes.ints+(dtypes.weakint,), name="x"), lambda x,ctx:
create_bounded(f"load{len(ctx[1])}", x.dtype.min, x.dtype.max, ctx[0])),
(UPat((Ops.LOAD, Ops.INDEX), dtypes.bool, name="x"), lambda x,ctx: (z3.Bool(f"load{len(ctx[1])}", ctx=ctx[0]), None)),
# constants
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x,ctx: (z3.Int("Invalid", ctx=ctx[0]), None)),
(UPat(Ops.CONST, dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx: (z3.IntVal(x.arg, ctx=ctx[0]), None)),
(UPat(Ops.CONST, dtypes.ints+(dtypes.weakint,), name="x"), lambda x,ctx: (z3.IntVal(x.arg, ctx=ctx[0]), None)),
(UPat(Ops.CONST, dtypes.bool, name="x"), lambda x,ctx: (z3.BoolVal(x.arg, ctx=ctx[0]), None)),
# casts from floats create new variables
(UPat(Ops.CAST, dtypes.ints+(dtypes.index,), src=(UPat(dtype=dtypes.floats),), name="x"), lambda x,ctx:
(UPat(Ops.CAST, dtypes.ints+(dtypes.weakint,), src=(UPat(dtype=dtypes.floats),), name="x"), lambda x,ctx:
create_bounded(f"cast{len(ctx[1])}", x.dtype.min, x.dtype.max, ctx[0])),
# A comparison between floats introduces a new bool variable
(UPat(GroupOp.Comparison, src=UPat(dtype=dtypes.floats), name="x"), lambda x,ctx: (z3.Bool(f"float_cmp{len(ctx[1])}", ctx=ctx[0]), None)),
# casts from bool/int to int/bool
(UPat(Ops.CAST, dtypes.ints+(dtypes.index,),src=(UPat.var("x", dtypes.bool),), name="c"), lambda x,c,ctx: (z3.If(ctx[1][x], 1, 0), None)),
(UPat(Ops.CAST, dtypes.ints+(dtypes.index,), src=(UPat.var("x", dtypes.ints+(dtypes.index,)),), name="c"), lambda x,c,ctx: (ctx[1][x], None)),
(UPat(Ops.CAST, dtypes.ints+(dtypes.weakint,),src=(UPat.var("x", dtypes.bool),), name="c"), lambda x,c,ctx: (z3.If(ctx[1][x], 1, 0), None)),
(UPat(Ops.CAST, dtypes.ints+(dtypes.weakint,), src=(UPat.var("x", dtypes.ints+(dtypes.weakint,)),), name="c"), lambda x,c,ctx: (ctx[1][x], None)),
(UPat(Ops.CAST, dtypes.bool, name="x"), lambda x,ctx: (ctx[1][x.src[0]]!=0, None)),
(UPat(GroupOp.ALU, name="x"), lambda x,ctx: (z3_alu[x.op](*(ctx[1][s] for s in x.src)), None)),
])

def uops_to_z3(solver:z3.Solver, *uops: UOp) -> list[z3.ExprRef]:
lst = list(UOp.sink(*uops).toposort(gate=lambda x: x.dtype.scalar() in dtypes.ints+(dtypes.bool, dtypes.index) or x.op is Ops.SINK))[:-1]
lst = list(UOp.sink(*uops).toposort(gate=lambda x: x.dtype.scalar() in dtypes.ints+(dtypes.bool, dtypes.weakint) or x.op is Ops.SINK))[:-1]
z3map: dict[UOp, z3.ExprRef] = {}
for u in lst:
z3_rewritten = z3_renderer.rewrite(u, ctx=(solver.ctx, z3map))
@@ -74,7 +74,7 @@ const layoutUOp = (g, { graph, change }, opts) => {
if (!opts.showIndexing) {
for (const n of g.nodes()) {
const node = g.node(n);
if (node.label.includes("dtypes.index")) g.removeNode(n);
if (node.label.includes("dtypes.weakint")) g.removeNode(n);
}
}
if (!opts.showCallSrc || opts.callSrcMask.size > 0) {
@@ -104,7 +104,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
# always exclude DEVICE/CONST/UNIQUE
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u)
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.index and u is not x: excluded.add(u)
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.weakint and u is not x: excluded.add(u)
if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u)
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)