add dtypes.index (#12015)
* add dtypes.index
* cast shape, stride and mask to dtypes.index in view.create
* move pm_lower_index_dtype to ops
* DEFINE_VAR is dtype.index by default
* merge var_val_using_str
* remove int from commutative
* fix test_rewrite_map
* change that to dtypes.index
* change some int to index
* shorten those
* remove old cast in renderer
* cleanup
* change that back
* add comment
* delete comment
* just delete those
* view doesnt have to cast anymore
* adjust comment
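In short: shape and indexing math now flows through an abstract `dtypes.index`, which a late codegen pass lowers to a concrete `int32` or `int64` based on the value range of each expression. A standalone sketch of that selection rule (`select_concrete_dtype` is an illustrative name, but the rule mirrors `select_dtype` in this diff):

```python
# Illustrative sketch of overflow-driven index lowering, independent of tinygrad.
INT32_MIN, INT32_MAX = -2**31, 2**31 - 1

def select_concrete_dtype(vmin: int, vmax: int) -> str:
    # an index expression only needs int64 if its value range escapes int32
    return "int64" if vmin < INT32_MIN or vmax > INT32_MAX else "int32"

assert select_concrete_dtype(0, 2**31 + 3) == "int64"    # a huge arange needs 64-bit indexing
assert select_concrete_dtype(0, 2048 * 2048) == "int32"  # fits comfortably in 32-bit
```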
@@ -26,6 +26,10 @@ import unittest
 import numpy as np
 import torch
 from tinygrad import Tensor, dtypes, nn
+from tinygrad.device import is_dtype_supported
+from tinygrad.helpers import getenv
+
+MOCKGPU = getenv("MOCKGPU")

 class TestNaNEdgeCases(unittest.TestCase):
   # we don't need more of these. it's unclear if torch's behavior is desired here

@@ -167,34 +171,6 @@ class TestZeroFolding(unittest.TestCase):
     with self.assertRaises(RuntimeError):
       (x % x).numpy()

-class TestArangeUOpValidationIssue(unittest.TestCase):
-  # these fail with UOp verification error.
-  # we don't need more of these involving arange
-
-  @unittest.expectedFailure
-  def test_large_arange_sum(self):
-    # Summing a huge arange should either succeed or raise a MemoryError.
-    n = 2**31 + 3
-    expected = (n - 1) * n // 2
-    out = Tensor.arange(n).sum().item()
-    self.assertEqual(out, expected)
-
-  @unittest.expectedFailure
-  def test_large_arange_index(self):
-    # Indexing a huge arange should return the correct value instead of failing
-    # with a UOp verification error.
-    n = 2**31 + 3
-    out = Tensor.arange(n)[0].item()
-    self.assertEqual(out, 0)
-
-  @unittest.expectedFailure
-  def test_large_arange_permute(self):
-    # Permuting a huge tensor should not trigger UOp verification failures.
-    n = 2**31 + 3
-    out = Tensor.arange(n).reshape(n, 1).permute(1, 0)
-    self.assertEqual(out.shape, (1, n))
-    out.realize()
-
 class TestAssignIssues(unittest.TestCase):
   # these are good failures. i'm not sure we need more, but we need to fix these.

@@ -230,10 +206,8 @@ class TestUOpValidationIssue(unittest.TestCase):
   # these fail with UOp verification error.
   # we want more of these with diverse errors!

-  @unittest.expectedFailure
+  @unittest.skipIf((not is_dtype_supported(dtypes.long)) or MOCKGPU, "hangs gpuocelot")
   def test_tensor_index_overflow(self):
-    # Advanced indexing on tensors expanded past int32 should not error, but
-    # tinygrad fails with a UOp verification error.
     val = Tensor([1])
     big = val.expand(2**31 + 3)
     idx = Tensor([0, 2**31 + 2])

@@ -273,4 +247,4 @@ class TestEdgeCases(unittest.TestCase):

 if __name__ == "__main__":
   unittest.main()

@@ -928,6 +928,7 @@ class TestIdxUpcast(unittest.TestCase):
     self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32))

   @unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
+  @unittest.expectedFailure # bug in gpu dims limiting
   def test_int64_unsupported_overflow(self):
     with self.assertRaises(KeyError):
       self.do_op_then_assert(dtypes.long, 2048, 2048, 2048)

@@ -212,7 +212,7 @@ class TestUOpGraph(unittest.TestCase):
   def test_where_same_fold(self):
     v = UOp.variable('tmp', 0, 1)
-    c0 = UOp(Ops.CONST, dtypes.int, arg=0)
+    c0 = UOp(Ops.CONST, dtypes.index, arg=0)
     vc = UOp(Ops.CMPNE, dtypes.bool, (v, c0))
     c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
     out = UOp(Ops.WHERE, dtypes.float, (vc, c1, c1))

@@ -398,7 +398,7 @@ class TestUOpGraph(unittest.TestCase):
     self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1)

   def test_depth_2_const_fold(self):
-    v = UOp.variable("tmp", 0, 1)
+    v = UOp.variable("tmp", 0, 1, dtypes.int)
     c2 = UOp(Ops.CONST, dtypes.int, arg=2)
     c4 = UOp(Ops.CONST, dtypes.int, arg=4)
     vc = UOp(Ops.ADD, dtypes.int, (v, c2))

@@ -417,6 +417,17 @@ class TestUOpGraph(unittest.TestCase):
       uops = to_uops_list([v.bitcast(dt)])
       self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST]), 0, f"dtype = {dt}")

+  def test_load_idx_becomes_int(self):
+    d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0)
+    d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 1)
+    l0 = UOp(Ops.LOAD, dtypes.long, (d0.index(UOp.const(dtypes.int, 0)),))
+    idx = l0 * 600
+    valid = (l0<-1).ne(True)&(l0<3000)
+    l1 = UOp(Ops.LOAD, dtypes.long, (d1.index(idx, valid),))
+    uops = to_uops_list([l1])
+    for u in uops:
+      if u.op is Ops.INDEX: self.assertEqual(u.src[1].dtype, dtypes.int)
+
   def test_in_out_of_bounds_access(self):
     with Context(IGNORE_OOB=0):
       glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)

@@ -512,7 +523,7 @@ class TestUOpGraph(unittest.TestCase):
   def test_in_out_bounds_access_with_mask(self):
     with Context(IGNORE_OOB=0):
       glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
-      gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 42),), "gidx0")
+      gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 42),), "gidx0")
       ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, (5<gidx0)&(gidx0<16)),))
       ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<16),))
       to_uops_list([ld0, ld1])

@@ -536,9 +547,9 @@ class TestUOpGraph(unittest.TestCase):
     with Context(IGNORE_OOB=0):
       glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
       glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
-      gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 42),), "gidx0")
-      ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<8),))
-      ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<32)),))
+      gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 42),), "gidx0")
+      ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<8),)).cast(dtypes.index)
+      ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<32)),)).cast(dtypes.index)
       to_uops_list([ld1])

       ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<64)),))

@@ -97,19 +97,22 @@ class TestFoldingAndReduction(unittest.TestCase):

 class TestModuloAndDivisionFolding(unittest.TestCase):
   def test_full_graph_rewrite_modulo_folding_with_define_var(self):
-    x_var_uop = UOp.variable('x', 0, 100)
+    # index dtype because div-mod rules only work on index
+    x_var_uop = UOp.variable('x', 0, 100).cast(dtypes.index)
     optimized_mod_uop = apply_rewrite(((x_var_uop * 4) + 2) % 4)
     self.assertEqual(optimized_mod_uop.op, Ops.CONST)
     self.assertEqual(optimized_mod_uop.arg, 2)

   def test_full_graph_rewrite_division_folding_with_define_var(self):
-    n_var_uop = UOp.variable('n', 1, 1000)
+    # index dtype because div-mod rules only work on index
+    n_var_uop = UOp.variable('n', 1, 1000).cast(dtypes.index)
     optimized_div_uop = apply_rewrite((n_var_uop * 6) // 3)
     self.assertEqual(optimized_div_uop.op, Ops.MUL)
     self.assertEqual(optimized_div_uop.src[1].arg, 2)

   def test_full_graph_rewrite_complex_mod_div_folding(self):
-    k_var_uop = UOp.variable('k', 0, 50)
+    # index dtype because div-mod rules only work on index
+    k_var_uop = UOp.variable('k', 0, 50).cast(dtypes.index)
     optimized_div_uop = apply_rewrite(((k_var_uop * 12 + 8) % 6) // 2)
     self.assertEqual(optimized_div_uop.op, Ops.CONST)
     self.assertEqual(optimized_div_uop.arg, 1)

@@ -126,8 +129,9 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
       if opt.op is Ops.VECTORIZE: self.assertFalse(all_same(opt.src))

   def test_full_graph_rewrite_modulo_large_divisor(self):
+    # index dtype because div-mod rules only work on index
     x_var_uop = UOp.variable('x', 1, 5)
-    self.assertIs(apply_rewrite(x_var_uop % 10), x_var_uop)
+    self.assertIs(apply_rewrite(x_var_uop.cast(dtypes.index) % 10).render(simplify=False), x_var_uop.render(simplify=False))

   def test_full_graph_rewrite_division_with_remainder(self):
     x_var_uop = UOp.variable('x', 7, 9)
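As the comments in these tests say, div/mod folding rules now fire only on `dtypes.index`, so symbolic expressions must be cast before simplification. A minimal sketch of the pattern, assuming a tinygrad checkout that includes this change:

```python
from tinygrad import dtypes
from tinygrad.uop.ops import UOp

# cast the variable to the index dtype so the div/mod folding rules can fire
x = UOp.variable('x', 0, 100).cast(dtypes.index)
expr = ((x * 4) + 2) % 4
print(expr.simplify().render())  # folds to the constant 2, as the test above asserts
```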
@@ -46,8 +46,8 @@ class TestRewriteMap(unittest.TestCase):

   def test_add_zero(self):
     # Build a small graph: add(0, add(const=0, const=5))
-    zero_node = UOp.const(dtypes.int, 0)
-    five_node = UOp.const(dtypes.int, 5)
+    zero_node = UOp.const(dtypes.index, 0)
+    five_node = UOp.const(dtypes.index, 5)
     inner_add = zero_node + five_node
     root_add = zero_node + inner_add

@@ -67,7 +67,7 @@ class TestRewriteMap(unittest.TestCase):
     Test rewriting neg(neg(5)) => 5 using symbolic.
     """
     # In some versions of TinyGrad, you might do: (-(-five_node))
-    five_node = UOp.const(dtypes.int, 5)
+    five_node = UOp.const(dtypes.index, 5)
     # If your code allows UOp(...), do that; else you might do something like:
     # double_neg_five = -(-five_node)
     # But let's be explicit:

@@ -85,8 +85,8 @@ class TestRewriteMap(unittest.TestCase):
     """
     Combine both rewrites: add(0, neg(neg(5))) => add(0, 5) => 5
     """
-    zero_node = UOp.const(dtypes.int, 0)
-    five_node = UOp.const(dtypes.int, 5)
+    zero_node = UOp.const(dtypes.index, 0)
+    five_node = UOp.const(dtypes.index, 5)
     neg_five = -five_node
     double_neg_five = -neg_five
     root_add = zero_node + double_neg_five

@@ -103,7 +103,7 @@ class TestRewriteMap(unittest.TestCase):
   def test_multi_var_rewrites(self):
     x_var = UOp.variable('x', 0, 10)
     y_var = UOp.variable('y', -5, 5)
-    zero_node = UOp.const(dtypes.int, 0)
+    zero_node = UOp.const(dtypes.index, 0)

     sum_with_zero = y_var + zero_node  # (y + 0)
     combined = x_var + sum_with_zero  # x + (y + 0)

@@ -155,8 +155,8 @@ class TestRewriteMap(unittest.TestCase):
     x_var = UOp.variable('x', 1, 10)
     y_var = UOp.variable('y', -5, 5)
     z_var = UOp.variable('z', 0, 5)
-    zero_node = UOp.const(dtypes.int, 0)
-    one_node = UOp.const(dtypes.int, 1)
+    zero_node = UOp.const(dtypes.index, 0)
+    one_node = UOp.const(dtypes.index, 1)

     # Build sub-expressions
     yz_sum = y_var + z_var  # (y + z)

@@ -18,7 +18,7 @@ def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
    UOp(Ops.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
  ))

-def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.int32, (UOp.const(dtypes.int, nmax),), expr)
+def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, nmax),), expr)
 def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
 def Range(n, nmax): return UOp.range(nmax, n)

@@ -2,22 +2,20 @@
 import unittest, pickle, functools, math
 import z3

-from tinygrad.dtype import dtypes, ConstType
-from tinygrad.codegen import full_rewrite
-from tinygrad.codegen.late.devectorizer import sym
+from tinygrad.dtype import dtypes, ConstType, DType
 from tinygrad.helpers import Context
-from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer
-from tinygrad import Variable
+from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer, track_rewrites
+from tinygrad.uop.symbolic import sym
 from tinygrad.uop.spec import uops_to_z3

-def render(self) -> tuple[str, ConstType, ConstType]:
-  # NOTE: we need STORE so the ALU op has children
-  glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
-  uops = full_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), self)).sink())
-  rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1]
-  return rewritten_uop.render(simplify=False), rewritten_uop.vmin, rewritten_uop.vmax
+@track_rewrites(name="simplify symbolic uop")
+def render(v) -> UOp:
+  v_simplified = graph_rewrite(v, sym)
+  return v_simplified

-def uconst(val): return UOp.const(dtypes.int, val)
+def Variable(name: str, min_val: ConstType, max_val: ConstType, dtype: DType=dtypes.index): return UOp.variable(name,min_val,max_val,dtype)
+def uconst(val): return UOp.const(dtypes.index, val)
 def usum(ops): return functools.reduce(lambda x,y: x+y, ops)
 def uand(ops): return functools.reduce(lambda x,y: x*y, ops)

@@ -30,11 +28,12 @@ class TestSymbolicPickle(unittest.TestCase):

 class TestSymbolic(unittest.TestCase):
   def helper_test_variable(self, v, n, m, s, test_z3:bool=True):
+    v_simplified = render(v)
     if test_z3:
       solver = z3.Solver()
-      expr, expr_simplified = uops_to_z3(solver, v, v.simplify())
+      expr, expr_simplified = uops_to_z3(solver, v, v_simplified)
       self.assertEqual(solver.check(expr != expr_simplified), z3.unsat, "simplified expression not equal to original")
-    rendered, nmin, nmax = render(v)
+    rendered, nmin, nmax = v_simplified.render(simplify=False), v_simplified.vmin, v_simplified.vmax
     if isinstance(s, tuple): self.assertIn(rendered, s)
     else: self.assertEqual(rendered, s)
     self.assertEqual(nmin, n)

@@ -111,7 +110,7 @@ class TestSymbolic(unittest.TestCase):
     self.helper_test_variable(-Variable("a", 0, 8), -8, 0, "(a*-1)")

   def test_xor_0(self):
-    self.helper_test_variable(Variable("a", 0, 8) ^ 0, 0, 8, "a")
+    self.helper_test_variable(Variable("a", 0, 8, dtypes.int) ^ 0, 0, 8, "a")

   def test_add_1(self):
     self.helper_test_variable(Variable("a", 0, 8)+1, 1, 9, "(a+1)")

@@ -209,12 +208,12 @@ class TestSymbolic(unittest.TestCase):
     self.assertEqual((Variable("x", -10, 0)%Variable("y", 1, 10))._min_max, (-9, 0))

   def test_range_div_its_symbolic_bound(self):
-    a = Variable("a", 1, 10)
+    a = Variable("a", 1, 10, dtypes.index)
     ridx0 = UOp.range(a+2, 0)
     self.helper_test_variable(ridx0//(a+2), 0, 0, "0")

   def test_range_mod_its_symbolic_bound(self):
-    a = Variable("a", 1, 10)
+    a = Variable("a", 1, 10, dtypes.index)
     ridx = UOp.range(a+2, 0)
     self.helper_test_variable(ridx%(a+2), 0, 11, "ridx0")

@@ -463,8 +462,8 @@ class TestSymbolic(unittest.TestCase):
     self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -8, 0, "(((idx*10)//11)*-1)")

   def test_nest_div_negative_factor(self):
-    ridx0=UOp.variable("ridx0", 0, 9)
-    ridx1=UOp.variable("ridx1", 0, 6)
+    ridx0=Variable("ridx0", 0, 9)
+    ridx1=Variable("ridx1", 0, 6)
     self.helper_test_variable(((((ridx0*-7)+ridx1)+63)//35), 0, 1, "(((ridx0//5)*-1)+1)")

   def test_div_into_mod(self):

@@ -533,8 +532,8 @@ class TestSymbolic(unittest.TestCase):
     self.helper_test_variable(x//y, 2, 2, "2")
     self.helper_test_variable(x%y, 0, 7, "(x+(y*-2))")
     # ensure all 4 corners are checked
-    x = Variable("x", -10, 10)
-    y = Variable("y", -8, 9)
+    x = Variable("x", -10, 10, dtypes.int)
+    y = Variable("y", -8, 9, dtypes.int)
     self.helper_test_variable(x//y, -2147483648, 2147483647, "(x//y)")
     self.helper_test_variable(x%y, -2147483648, 2147483647, "(x%y)")

@@ -587,6 +586,12 @@ class TestSymbolic(unittest.TestCase):
     unrolled_div = (gidx+2561)//4+(gidx+2562)//4+(gidx+2560)//4+(gidx+2559)//4
     self.helper_test_variable(unrolled_div, 2559, 5118, "(gidx+2559)")

+  def test_arange_unrolled4_with_cast(self):
+    gidx = Variable("gidx", 0, 2559, dtypes.index)
+    dt = dtypes.int
+    unrolled_div = ((gidx+2561)//4 + 2).cast(dt)+((gidx+2562)//4).cast(dt)+((gidx+2560)//4).cast(dt)+((gidx+2559)//4).cast(dt)
+    self.helper_test_variable(unrolled_div, 2561, 5120, "((int)(gidx)+2561)")
+
   def test_arange_unrolled4_mul(self):
     gidx = Variable("gidx", 0, 2559)
     unrolled_div = 2*((gidx+2561)//4)+2*((gidx+2562)//4)+2*((gidx+2560)//4)+2*((gidx+2559)//4)

@@ -688,10 +693,10 @@ class TestSymbolic(unittest.TestCase):
     self.helper_test_variable(-a<-b, False, True, "(b<a)")

   def test_where_cast(self):
-    s = Variable("s", 0, 3)
+    s = Variable("s", 0, 3, dtypes.int)
     cond = s < 2
-    a = Variable("a", 0, 3)
-    b = Variable("b", 0, 3)
+    a = Variable("a", 0, 3, dtypes.int)
+    b = Variable("b", 0, 3, dtypes.int)
     expr = cond.where(a, b).cast(dtypes.half)

     # TODO: copied from render, render does not support cast

@@ -709,6 +714,7 @@ class TestSymbolic(unittest.TestCase):
     expr = cond1.where(cond2.where(a, b), b)
     self.helper_test_variable(expr, 0, 3, "(a if ((s<6)&(2<s)) else b)")

+  @unittest.expectedFailure # needs simplify_valid which is not in render anymore
   def test_where_merge_branches2(self):
     cond1 = Variable("s", 0, 10) < 5
     cond2 = Variable("s", 0, 10) < 6

@@ -738,8 +744,8 @@ class TestSymbolic(unittest.TestCase):
     self.helper_test_variable(a.trunc(), 1, 10, "a", test_z3=False)

   def test_do_math_in_int32(self):
-    a = Variable("a", 1, 10)
-    b = Variable("b", 1, 10)
+    a = Variable("a", 1, 10, dtypes.int)
+    b = Variable("b", 1, 10, dtypes.int)
     self.helper_test_variable(a.cast(dtypes.long)+b.cast(dtypes.long), 2, 20, "(long)((a+b))")
     self.helper_test_variable(a.cast(dtypes.long)*b.cast(dtypes.long), 1, 100, "(long)((a*b))")

@@ -142,7 +142,7 @@ class TestViz(BaseTestViz):

   def test_const_node_visibility(self):
     a = UOp.variable("a", 0, 10)
-    z = UOp.const(dtypes.int, 0)
+    z = UOp.const(dtypes.index, 0)
     alu = a*z
     exec_rewrite(alu, [sym])
     lst = get_viz_list()

@@ -2,7 +2,7 @@ from typing import Any, Callable
 import functools
 from dataclasses import dataclass
 from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
-from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp
+from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype
 from tinygrad.uop.spec import type_verify
 from tinygrad.renderer import Renderer

@@ -10,7 +10,7 @@ from tinygrad.renderer import Renderer
 from tinygrad.codegen.lowerer import pm_lowerer, get_index
 from tinygrad.codegen.quantize import pm_quant
 from tinygrad.codegen.gpudims import pm_add_gpudims
-from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing
+from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, cast_folding
 from tinygrad.uop.decompositions import get_late_rewrite_patterns
 from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander
 from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \

@@ -93,6 +93,9 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
   supported_ops = tuple(opts.code_for_op.keys())
   extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([])

+  # lower the index dtype to a concrete int
+  ret.append(RewriteStep(pm_lower_index_dtype+cast_folding+load_store_indexing, lambda _: opts.device, name="lower all index dtypes"))
+
   # optional pre matcher
   if opts.pre_matcher is not None: ret.append(RewriteStep(opts.pre_matcher, name="pre_matcher"))
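This new "lower all index dtypes" rewrite step is where the abstract index dtype becomes concrete. A hedged sketch of the entry point this diff adds (assuming a checkout with this change):

```python
from tinygrad import dtypes
from tinygrad.uop.ops import UOp, index_to_concrete_int

v = UOp.variable("n", 0, 100)               # DEFINE_VAR now defaults to dtypes.index
lowered = index_to_concrete_int(v * 4 + 2)  # pm_lower_index_dtype runs over the graph
print(lowered.dtype)                        # dtypes.int: vmax=402 fits in int32, so no upcast
```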
@@ -34,7 +34,7 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|None
   if max_sizes is not None and len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
   # try to split up dims: (a,) -> (b, c)
   if limited == dims: limited = _split_dims(dims, max_sizes) if max_sizes is not None else dims
-  ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (sint_to_uop(s),), (f"{prefix}{i}")) for i,s in enumerate(limited)]
+  ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.index, (sint_to_uop(s),), (f"{prefix}{i}")) for i,s in enumerate(limited)]
   if len(limited) < len(dims):
     ret = []
     if (contraction:=get_contraction(dims, limited)) is None: raise AssertionError(f"get_contraction should not be None {dims=} {limited=}")

@@ -4,7 +4,7 @@ from collections import defaultdict
 from dataclasses import dataclass
 from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace
 from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element
-from tinygrad.uop.symbolic import uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
+from tinygrad.uop.symbolic import uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat, cast_folding
 from tinygrad.helpers import getenv, flatten, AMX, prod, partition
 from tinygrad.renderer import Renderer

@@ -57,6 +57,9 @@ load_store_indexing = PatternMatcher([
   (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
   # index True is just Index
   (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat(Ops.CONST, arg=True))), lambda buf,start_idx: buf.index(start_idx)),
+  # remove hanging cast
+  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast()),), lambda buf,idx: buf.index(idx)),
+  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
   # delete_redundant_gates (after expand)
   (UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
     UPat.var("val")), name="store", allow_any_len=True), delete_redundant_gates),

@@ -316,12 +319,12 @@ def no_range(u:UOp) -> bool: return not any(x.op is Ops.RANGE for x in u.sparents)

 pm_reduce_collapse = PatternMatcher([
   # lift x+y out of reduce on lt
-  ((UPat.var("x")+UPat.var("y")) < UPat.var("c"), lambda x,y,c: (x < (c-y)) if no_range(y) and no_range(c) else None),
+  ((UPat.var("x")+UPat.var("y")).or_casted() < UPat.var("c"), lambda x,y,c: (x < (c.cast(y.dtype)-y)) if no_range(y) and no_range(c) else None),
   # lift x*y out of reduce
   ((UPat.var("x")*UPat.var("y")) < UPat.var("c"),
    lambda x,y,c: (x < ((c+y-1) // y)) if no_range(y) and no_range(c) and y.vmin > 0 else None),
   # lift x+y out of reduce on ne
-  ((UPat.var("x")+UPat.var("y")) != UPat.var("c"), lambda x,y,c: (x != (c-y)) if no_range(y) and no_range(c) else None),
+  ((UPat.var("x")+UPat.var("y")).or_casted() != UPat.var("c"), lambda x,y,c: (x != (c.cast(y.dtype)-y)) if no_range(y) and no_range(c) else None),
   # fold the range
   ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(0, UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True),
    lambda r,cut,val: (r.src[0]-cut).maximum(0).minimum(r.src[0]).cast(val.dtype) * val),

@@ -351,7 +354,7 @@ pm_reduce_collapse = PatternMatcher([
   (UPat(Ops.VECTORIZE, name="inp").reduce(name="red", allow_any_len=True), no_vectorized_reduce),
   # index/load/where. TODO: this is more aggressive than needed
   (UPat((Ops.INDEX, Ops.LOAD, Ops.WHERE), name="alu"), no_vectorized_alu),
-])+sym
+])+sym+cast_folding

 def reduce_collapse(red:UOp):
   included, not_included = partition(red.parents, lambda x: any(y in x.sparents for y in red.src[1:]))

@@ -151,7 +151,7 @@ def fix_group_for_reduce(x:UOp):
 pm_pre_expander = PatternMatcher([
   # rewrite UPCAST/UNROLL range to something to be expanded
   (UPat(Ops.RANGE, name="r"),
-   lambda r: UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(s:=r.vmax+1), tuple(range(s))),), ((r.arg[0],s),)) \
+   lambda r: UOp(Ops.UNROLL, r.dtype, (UOp.const(r.dtype.vec(s:=r.vmax+1), tuple(range(s))),), ((r.arg[0],s),)) \
     if r.arg[1] in {AxisType.UNROLL, AxisType.UPCAST} else None),
   # fix REDUCEs with UNROLLs
   (UPat(Ops.REDUCE, name="x"), fix_reduce_unroll),

@@ -298,6 +298,7 @@ class Compiled:

 # TODO: move this to each Device
 def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
+  if dtype == dtypes.index: return False
   if device is None: device = Device.DEFAULT
   if dtype == dtypes.bfloat16:
     if device == "METAL": return not CI
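`dtypes.index` is compiler-internal only, so `is_dtype_supported` now reports it unsupported on every device; the `dtypes.long` skips in the tests above go through the same helper. A short usage sketch (assuming a checkout with this change):

```python
from tinygrad import dtypes
from tinygrad.device import is_dtype_supported

assert not is_dtype_supported(dtypes.index)  # never a runtime dtype on any device
print(is_dtype_supported(dtypes.long))       # device-dependent; the int64 tests skip when False
```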
@@ -89,7 +89,7 @@ class dtypes:
   def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats or isinstance(x, ImageDType)
   @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
   @functools.cache
-  def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints
+  def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints + (dtypes.index,)
   @staticmethod
   @functools.cache
   def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints

@@ -128,6 +128,7 @@ class dtypes:
   @staticmethod
   def fields() -> dict[str, DType]: return DTYPES_DICT
   void: Final[DType] = DType.new(-1, 0, "void", None)
+  index: Final[DType] = DType.new(-1,100, "index", None)
   bool: Final[DType] = DType.new(0, 1, "bool", '?')
   int8: Final[DType] = DType.new(1, 1, "signed char", 'b')
   uint8: Final[DType] = DType.new(2, 1, "unsigned char", 'B')

@@ -164,7 +165,7 @@ class dtypes:
   uints = (uint8, uint16, uint32, uint64)
   sints = (int8, int16, int32, int64)
   ints = uints + sints
-  all = floats + ints + (bool,)
+  all = floats + ints + (bool, index)

 if (env_default_float := getenv("DEFAULT_FLOAT", "")):
   dtypes.default_float = getattr(dtypes, env_default_float.lower())

@@ -186,11 +187,12 @@ def _get_recursive_parents(dtype:DType) -> set[DType]:
   return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
 @functools.cache
 def least_upper_dtype(*ds:DType) -> DType:
-  return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
+  return min(set.intersection(*[_get_recursive_parents(d.scalar()) for d in ds])) \
+    if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
 def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.default_float)

-DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void"))}
-INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void"}
+DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void", "index"))}
+INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void", "index":"index"}

 @functools.cache
 def can_safe_cast(dt0:DType, dt1:DType) -> bool:

@@ -198,6 +200,7 @@ def can_safe_cast(dt0:DType, dt1:DType) -> bool:
   # https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html
   if dt0 == dt1 or dt0 == dtypes.bool: return True
   match dt1:
+    case dtypes.index: return dt0 in dtypes.ints
     case dtypes.double: return dt0 in (dtypes.float, dtypes.half, dtypes.bfloat16,
      dtypes.uint32, dtypes.uint16, dtypes.uint8, dtypes.int32, dtypes.int16, dtypes.int8)
     case dtypes.float: return dt0 in (dtypes.half, dtypes.bfloat16, dtypes.uint16, dtypes.uint8, dtypes.int16, dtypes.int8)

@@ -315,4 +318,4 @@ def _to_torch_dtype(dtype:DType) -> 'torch.dtype'|None: # type: ignore [name-defined] # noqa: F821
   except TypeError: return None
 @functools.cache
 def _from_torch_dtype(torchdtype:'torch.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
-  return {v:k for k in dtypes.all if (v:=_to_torch_dtype(k)) is not None}[torchdtype]
+  return {v:k for k in DTYPES_DICT.values() if (v:=_to_torch_dtype(k)) is not None}[torchdtype]
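Taken together, the dtype.py changes make `dtypes.index` a first-class integer dtype for range analysis while keeping it out of the public dtype tables (and therefore out of torch dtype conversion). A quick sketch of the resulting behavior (assuming a checkout with this change):

```python
from tinygrad import dtypes

assert dtypes.is_int(dtypes.index)      # index now counts as an integer dtype
assert dtypes.index in dtypes.all       # ...and is included in dtypes.all
assert "index" not in dtypes.fields()   # but it stays out of the public field dict, like void
```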
@@ -119,7 +119,7 @@ def map_reshape(idx:UOp, r:UOp):
   for s,src in list(zip(idx.shape, idx.src[1:]))[::-1]:
     to_sum.append(acc*src)
     acc *= s
-  mish = sum(to_sum, start=UOp.const(dtypes.int, 0))
+  mish = sum(to_sum, start=UOp.const(dtypes.index, 0))
   ret:list[UOp] = []
   for s in r.src[0].shape[::-1]:
     ret.append(mish % s) # NOTE: simplify will turn this to CONST

@@ -186,7 +186,7 @@ def map_partial_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp):
       ranges.append(idx.src[1+i])
       continue
     passthrough_idx.append(idx.src[1+i])
-    ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0))
+    ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.index, 0))
     new_ranges.append(ranges[-1])
   ret = x.src[0].index(*ranges).bufferize(*[x for x in new_ranges if x.op is not Ops.CONST], arg=x.device)
   return ret.index(*passthrough_idx)

@@ -195,7 +195,7 @@ def map_contiguous(ctx:RangeifyContext, x:UOp):
   if x.arg is not None: return None
   ranges = []
   for s in x.shape[len(x.src)-1:]:
-    ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0))
+    ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.index, 0))
   ret = x.src[0].index(*ranges).bufferize(*x.src[1:], *[x for x in ranges if x.op is not Ops.CONST], arg=x.device)
   return ret.shrink(((0, prod(x.shape)),)).forced_reshape(x.shape)

@@ -5,22 +5,8 @@ import functools
 from typing import Callable
 from tinygrad.helpers import merge_dicts, getenv
 from tinygrad.shape.view import View, unravel
-from tinygrad.dtype import dtypes
-from tinygrad.uop.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context, PatternMatcher, UPat, GroupOp
+from tinygrad.uop.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context
 from tinygrad.uop.symbolic import symbolic_flat, uop_given_valid, simplify_valid

-# If a node overflow, its srcs need to be checked to see if this overflow is the result of an ALU operation,
-# or that the node simply inherits the dtype from srcs. Upcast is either `Ops.CAST`+`replace` or just `replace`.
-def handle_upcast(u: UOp) -> UOp|None:
-  dtype = dtypes.int64.vec(u.dtype.count) if u.dtype.count > 1 else dtypes.int64
-  # check for overflow, upcast this to int64
-  if u.vmax > dtypes.max(dtypes.int) or u.vmin < dtypes.min(dtypes.int):
-    return u.replace(dtype=dtype, src=tuple([x.cast(dtype) for x in u.src]))
-  # if any inputs are int64 and this *doesn't* overflow, cast back to int
-  if any(x.dtype == dtypes.int64 for x in u.src):
-    return u.replace(dtype=dtype, src=tuple([x.cast(dtype) for x in u.src])).cast(u.dtype)
-  return None
-pm_upcast = PatternMatcher([(UPat(GroupOp.ALU, dtype=dtypes.int, name="u"), handle_upcast),])

 @functools.cache
 def views_to_indexed_uops(views: tuple[View, ...], _idxs:tuple[UOp, ...]|None=None) -> tuple[UOp, UOp]:

@@ -34,8 +20,8 @@ def views_to_indexed_uops(views: tuple[View, ...], _idxs:tuple[UOp, ...]|None=None) -> tuple[UOp, UOp]:
   # simplify
   if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
   if (newidx:=uop_given_valid(valid, idx)) is not None: idx = newidx
-  # symbolic again, upcast if needed
-  return graph_rewrite(UOp.sink(idx, valid), symbolic_flat+pm_upcast, name="indexing sym @ 2").src
+  # symbolic again
+  return graph_rewrite(UOp.sink(idx, valid), symbolic_flat, name="indexing sym @ 2").src

 @functools.cache
 def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[sint|None, ...]:

@@ -204,15 +204,15 @@ class View:

     # Merge dimensions in vm2 if required.
     # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
-    idxs: list[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
-    merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
+    idxs: list[UOp] = [UOp.variable(f"idx{i}", 0, s-1, dtypes.index) for i,s in enumerate(vm1.shape)]
+    merged_size, merged_term = 1, UOp.const(dtypes.index, 0)
     extents: list[tuple[sint, UOp]] = []
     for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
       merged_term += (sum([idxs[d1] * s1 for d1, s1 in term]) + o) * merged_size
       merged_size *= s
       if resolve(merged_term < merged_size, False) and resolve(0 <= merged_term, False):
         extents.append((merged_size, merged_term))
-        merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
+        merged_size, merged_term = 1, UOp.const(dtypes.index, 0)
     if resolve(merged_term != 0): return None
     if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
       if (reshaped_vm2 := vm2.reshape(vm2_shape)) is None: return None

@@ -8,7 +8,7 @@ from tinygrad.dtype import _from_np_dtype, _to_np_dtype
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
 from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, RANGEIFY, FUSE_ATTENTION
 from tinygrad.gradient import compute_gradient
-from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata
+from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata, index_to_concrete_int
 from tinygrad.uop.spec import tensor_uop_spec, type_verify
 from tinygrad.device import Device, Buffer
 from tinygrad.engine.realize import run_schedule

@@ -139,6 +139,8 @@ class Tensor(MathTrait):
     # create a UOp from the different types of inputs
     if isinstance(data, UOp):
       assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
+      # if data is dtype.index that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of
+      if data.dtype==dtypes.index: data = index_to_concrete_int(data)
       if data.op is Ops.BIND:
         var, val = data.unbind()
         # give the bound constant a device

@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
 from enum import Enum, auto
 from tinygrad.uop import Ops, GroupOp
 from tinygrad.uop.mathtraits import MathTrait
-from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType
+from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype
 from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
 from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey
 if TYPE_CHECKING:

@@ -307,7 +307,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def range(end:sint, *arg):
     if len(arg) == 0: raise RuntimeError("range needs an arg")
     if len(arg) == 1: arg = arg+(AxisType.LOOP,)
-    return UOp(Ops.RANGE, dtype=dtypes.int, src=(sint_to_uop(end),), arg=arg)
+    return UOp(Ops.RANGE, dtype=dtypes.index, src=(sint_to_uop(end),), arg=arg)
   def r(self, op:Ops, axis:tuple[int, ...]):
     axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
     if len(axis) == 0: return self

@@ -482,7 +482,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   # *** uop Variable stuff ***

   @staticmethod
-  def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.int) -> UOp:
+  def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.index) -> UOp:
     assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}"
     return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
   @property

@@ -573,7 +573,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
     if self.op is Ops.GEP: return self.src[0]._min_max
     # TODO: CAST to bool/unsigned is not monotone, still some case can be simplified
-    if self.op is Ops.CAST and self.dtype in (dtypes.floats+dtypes.sints):
+    if self.op is Ops.CAST and self.dtype in dtypes.floats+dtypes.sints+(dtypes.index,):
       return max(dtypes.min(self.dtype), self.src[0].vmin), min(self.src[0].vmax, dtypes.max(self.dtype))
     return dtypes.min(self.dtype), dtypes.max(self.dtype)

@@ -1025,7 +1025,30 @@ def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, na
   for k,v in input_map.items(): new_map[k] = new_map.get(v,v)
   return new_map

-def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x
+def sint_to_uop(x:sint) -> UOp: return UOp.const(dtypes.index, x) if isinstance(x, int) else x.cast(dtypes.index)
+
+def select_dtype(u): return (dtypes.long if u.overflows(dtypes.int32) else dtypes.int).vec(u.dtype.count)
+pm_lower_index_dtype = PatternMatcher([
+  # There are no Unary ops at this point in symbolic, those are introduced later
+  (UPat(GroupOp.Binary, dtypes.index, name="u", src=(UPat.var("x"), UPat.var("y"))), lambda u,x,y:
+    x.cast(dt:=least_upper_dtype(select_dtype(u), x.dtype, y.dtype)).alu(u.op, y.cast(dt))),
+  # comparison ops might now have different dtypes in their sources
+  (UPat(GroupOp.Comparison, name="u", src=(UPat.var("x",dtypes.ints), UPat.var("y", dtypes.ints))), lambda u,x,y:
+    x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)).alu(u.op, y.cast(dt)) if x.dtype!=y.dtype else None),
+  (UPat(Ops.WHERE, dtype=dtypes.index, src=(UPat.var("cond"), UPat.var("x"), UPat.var("y")), name="u"), lambda cond,u,x,y:
+    cond.where(x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)), y.cast(dt))),
+  (UPat((Ops.CONST, Ops.VCONST), dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=select_dtype(u))),
+  (UPat((Ops.RANGE,), dtype=dtypes.index, src=(UPat.var("end")), name="r"), lambda ctx,r,end:
+    r.replace(dtype=(dt:=select_dtype(r)), src=(end.cast(dt),))),
+  (UPat(Ops.CAST, dtype=dtypes.index, src=(UPat.var("x", dtypes.ints),), name="u"), lambda u,x: x),
+  (UPat(Ops.VECTORIZE, dtype=dtypes.index, name="u"), lambda u: u.replace(
+    dtype=(dt:=least_upper_dtype(*[x.dtype for x in u.src])).vec(u.dtype.count), src=tuple(x.cast(dt) for x in u.src))),
+  (UPat((Ops.SPECIAL,Ops.DEFINE_VAR), dtypes.index, name="u"), lambda u: u.replace(dtype=dtypes.int)),
+  (UPat((Ops.BIND), dtypes.index, name="u"), lambda u: u.replace(dtype=u.src[0].dtype)),
+])
+def index_to_concrete_int(u:UOp): return graph_rewrite(u, pm_lower_index_dtype)

 _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
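`pm_lower_index_dtype` is range-driven: constants and ALU nodes pick `int32` when their bounds fit and `int64` when they overflow, while `DEFINE_VAR`/`SPECIAL` drop to plain `int` and `BIND` inherits from its source. A hedged sketch (assuming a checkout with this change):

```python
from tinygrad import dtypes
from tinygrad.uop.ops import UOp, index_to_concrete_int

big = UOp.const(dtypes.index, 2**31 + 3)     # out of int32 range
small = UOp.const(dtypes.index, 42)
print(index_to_concrete_int(big).dtype)      # dtypes.long, chosen by select_dtype
print(index_to_concrete_int(small).dtype)    # dtypes.int
```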
@@ -28,17 +28,17 @@ try:
   # float loads only become a variable when they get cast to int/bool
   (UPat(Ops.LOAD, dtypes.ints, name="x"),
    lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0])))),
-  (UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,), name="x"),
+  (UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,dtypes.index), name="x"),
    lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx)))),
   # z3 can cast from bool to int automatically
-  (UPat(Ops.CAST, dtype=dtypes.ints, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
+  (UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
   (UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], x.src[0].arg[1]!=0))),
   # if the source of the cast is not a noop it means that it is a float and so we create a new variable
-  (UPat(Ops.CAST, dtype=dtypes.ints, name="x"), lambda x,ctx:
+  (UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx:
    UOp(Ops.NOOP, arg=(ctx[0], create_bounded(f"cast{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0])))),
   (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x,ctx:
    UOp(Ops.NOOP, arg=(ctx[0], z3.Bool(f"cast{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx)))),
-  (UPat(Ops.XOR, src=UPat(Ops.NOOP), name="x"),
+  (UPat(Ops.XOR, dtype=dtypes.ints+(dtypes.bool, ), src=UPat(Ops.NOOP), name="x"),
    lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], z3.BV2Int(z3_alu[x.op](*(z3.Int2BV(s.arg[1], x.dtype.itemsize*8) for s in x.src)))))),
   (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], z3_alu[x.op](*(s.arg[1] for s in x.src))))),
   # A comparison between floats introduces a new bool variable

@@ -95,7 +95,7 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
   (UPat(Ops.VIEW, src=(UPat.var("x"),)), lambda x: x.base.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.CONST, Ops.DEVICE}),

   # Tensor variable bindings
-  (UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
+  (UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.index,))), arg=None), lambda: True),

   # Tensor const has a device and an unmasked ShapeTracker of stride 0
   # NOTE: variables in shape can cause multiple views in this ShapeTracker and other issues, see TestSymbolicJit.test_ones_sum

@@ -120,6 +120,7 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
 # ***** uop type spec *****

 def validate_index(idx:UOp, gate:UOp=UOp.const(dtypes.bool, True)):
+  # TODO: check for overflow
   if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := idx.src[0].ptrdtype.size) == -1: return True
   # We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask
   if 0<=idx.src[1].vmin and idx.src[1].vmax<sz: return True

@@ -175,6 +176,9 @@ spec = PatternMatcher([

   # **** new style load/store ****

+  # make sure all index dtypes have been lowered
+  (UPat(GroupOp.All, dtype=dtypes.index), lambda: False),
+
   # INDEX is used in new style load/store
   # INDEX takes a <buf, alu, gate?>
   (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat())), lambda: True),

@@ -26,7 +26,7 @@ symbolic_simple = PatternMatcher([
   # ** self folding **
   (UPat.var("x") + 0, lambda x: x),    # x+0 -> x
   (UPat.var("x") * 1, lambda x: x),    # x*1 -> x
-  (UPat.var("x", dtype=dtypes.ints) ^ 0, lambda x: x),  # x^0 -> x
+  (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) ^ 0, lambda x: x),  # x^0 -> x
   (UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)),  # x//x -> 1
   (UPat.var("x") // 1, lambda x: x),   # x//1 -> x
   (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x

@@ -49,11 +49,11 @@ symbolic_simple = PatternMatcher([
   (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
   (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
   (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, False), UPat.const(dtypes.bool, True)), lambda x: x.logical_not()),
-  (UPat.var("x", dtype=dtypes.ints+(dtypes.bool,)).trunc(), lambda x: x),
+  (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)).trunc(), lambda x: x),
   # ** zero folding **
   (UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))),  # x < x -> False
   (UPat.var("x") % UPat.var("x"), lambda x: x.const_like(0)),  # x%x -> 0
-  (UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints),
+  (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) != UPat.var("x"),
    lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))),  # x != x -> False (only ints)
   # x*0 -> 0 or 0*x -> 0
   # if x is nan or inf it should render the nan value.

@@ -108,6 +108,7 @@ def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None:
     if fac!=1:
       if u.op is not Ops.MUL or u.src[1].op is not Ops.CONST or u.src[1].arg != fac: return None
       u = u.src[0]
+    if u.op is Ops.CAST and u.src[0].dtype == dtypes.index: u = u.src[0]
     if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None
     if denominator != u.src[1].arg: return None
     if (s0:=u.src[0]).vmin < 0: return None

@@ -123,7 +124,7 @@ def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None:
   for i in range(denominator-len(seen_const)):
     if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
   if sorted(seen_const)==list(range(denominator)):
-    return fac*ans
+    return (fac*ans).cast(divs.dtype)
   return None

 def lt_folding(x:UOp, c:int) -> UOp|None:

@@ -270,10 +271,19 @@ gep_pushing = PatternMatcher([
   (UPat(Ops.WMMA, name="wmma").f(Ops.GEP, name="gep"), gep_through_wmma),
 ])

+cast_folding = PatternMatcher([
+  (UPat.var('x', dtypes.ints+(dtypes.index,)).cast(dtypes.ints+(dtypes.index,), name="a").cast(name="b"),
+    lambda x,a,b: x.cast(b.dtype) if a.dtype.min<=x.vmin and x.vmax<=a.dtype.max else None),
+  # try to do math in int instead of long
+  (UPat(GroupOp.Binary, src=(UPat.var("x", dtypes.long), UPat.var("y", dtypes.long)), name="u"), lambda u,x,y:
+    x.cast(dtypes.int).alu(u.op, y.cast(dtypes.int)).cast(u.dtype) if not any(v.overflows(dtypes.int) for v in (u,x,y)) else None),
+  ((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
+])
+
 commutative = PatternMatcher([
-  # ** COMMUTATIVE flipping (only for ints) **
+  # ** COMMUTATIVE flipping (only for index) **
   # NOTE: this can break merging vector math by only flipping some of them
-  (UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
+  (UPat(GroupOp.Commutative, dtype=dtypes.index, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
 ])

 symbolic = symbolic_simple+commutative+PatternMatcher([
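`cast_folding` also carries the "do math in int instead of long" rule that previously lived inline in `symbolic`: a binary op on two int64 values that provably cannot overflow int32 is narrowed, computed in int, and cast back. A sketch grounded in `test_do_math_in_int32` above (assuming a checkout with this change):

```python
from tinygrad import dtypes
from tinygrad.uop.ops import UOp

a = UOp.variable("a", 1, 10, dtypes.int).cast(dtypes.long)
b = UOp.variable("b", 1, 10, dtypes.int).cast(dtypes.long)
# neither a+b (max 20) nor its operands can overflow int32, so the add is done in int
print((a + b).simplify().render())  # expected "(long)((a+b))" per the test above
```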
@@ -289,10 +299,6 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
   ((UPat.var("y") + UPat.var("x")) + UPat.var("x"), lambda y,x: y+x*2),
   ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3) if x2 is not x3 else None),  # (x/x2)/x3 -> x/(x2*x3)
   (-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)),  # -(x+c) -> -x + -c
-  (UPat.var('x', dtypes.ints).cast(dtypes.ints, name="a").cast(name="b"),
-   lambda x,a,b: x.cast(b.dtype) if a.dtype.min<=x.vmin and x.vmax<=a.dtype.max else None),
-  (UPat(GroupOp.Binary, src=(UPat.var("x",dtypes.long), UPat.var("y", dtypes.long)), name="u"), lambda u,x,y:
-   x.cast(dtypes.int).alu(u.op, y.cast(dtypes.int)).cast(u.dtype) if not any(v.overflows(dtypes.int) for v in (u,x,y)) else None),
   # a conditional with the same results either way is a noop, also fold const conditionals
   (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
   (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),

@@ -317,27 +323,27 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
   ((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)),  # (x//c1)//c2 -> x//(c1*c2)
   # ** lt **
   # c0*x<c1 for positive int c0,c1
-  ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False),
+  ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.index))<UPat.cvar("c1", vec=False),
    lambda x,c0,c1: x<math.ceil(c1.arg/c0.arg) if c0.arg > 0 and c1.arg > 0 else None),
   # c0*x<c1 for negative int c0 and non-positive c1
-  ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False),
+  ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.index))<UPat.cvar("c1", vec=False),
    lambda x,c0,c1: (-x)<(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
   # x//d<c
-  ((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("d", vec=False))<UPat.cvar("c", vec=False),
+  ((UPat.var("x", dtype=dtypes.index)//UPat.cvar("d", vec=False))<UPat.cvar("c", vec=False),
    lambda x,d,c: (x<(c.arg*d.arg) if c.arg > 0 else x<(c.arg*d.arg-(d.arg-1))) if d.arg > 0 else None),
   # ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
   ((UPat.var("x") + UPat.cvar("c1")) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
   ((UPat.var("x") * UPat.cvar("c1")) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
   # *** rules from symbolic ***
   # unrolled arange div folding
-  ((UPat() + UPat()//UPat.cvar("d", vec=False)).named("divs"), lambda divs,d: fold_unrolled_divs(divs, d.arg)),
-  ((UPat() + (UPat()//UPat.cvar("d", vec=False))*UPat.cvar("c")).named("divs"), lambda divs,d,c: fold_unrolled_divs(divs, d.arg, c.arg)),
+  ((UPat()+(UPat()//UPat.cvar("d", vec=False)).or_casted()).named("divs"), lambda divs,d: fold_unrolled_divs(divs, d.arg)),
+  ((UPat()+((UPat()//UPat.cvar("d", vec=False)).or_casted()*UPat.cvar("c"))).named("divs"), lambda divs,d,c: fold_unrolled_divs(divs, d.arg, c.arg)),
   # generic lt folding
-  (UPat.var("x", dtypes.sints)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
-  (UPat.var("x", dtypes.sints)*-1 < UPat.var("y", dtypes.sints)*-1, lambda x,y: y<x),
+  (UPat.var("x", dtypes.index)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
+  (UPat.var("x", dtypes.index)*-1 < UPat.var("y")*-1, lambda x,y: y<x),
   # canonicalize a simplex with positive coefficients > 0
   # not x < 1 -> X > 0
-  ((UPat.var("x", dtypes.ints)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
+  ((UPat.var("x", dtypes.index)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
   # ** div **
   # div folding
   ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)

@@ -345,30 +351,28 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
   # a range mod its own upper bound is just the range
   (UPat(Ops.RANGE, src=UPat.var("end"), name="r")%UPat.var("end"), lambda r,end: r),
   (UPat(Ops.RANGE, src=UPat.var("end"), name="r")//UPat.var("end"), lambda r,end: r.const_like(0)),
-  (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.var("y"))), cancel_divmod),
-  (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_binary_numerator),
-  (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_divmod_congruence),
-  (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), divide_by_gcd),
-  (UPat(Ops.MOD, dtypes.sints, name="m", src=(UPat.var("x"), UPat.cvar("y", vec=False))), remove_nested_mod),
-  (UPat((Ops.IDIV), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), nest_div_by_smallest_factor),
-  (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), simplify_remainder),
+  (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.var("y"))), cancel_divmod),
+  (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_binary_numerator),
+  (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_divmod_congruence),
+  (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), divide_by_gcd),
+  (UPat(Ops.MOD, dtypes.index, name="m", src=(UPat.var("x"), UPat.cvar("y", vec=False))), remove_nested_mod),
+  (UPat((Ops.IDIV), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), nest_div_by_smallest_factor),
+  (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), simplify_remainder),
   (UPat.var("x") // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax < 0 else None),
   (UPat.var("x") // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax <=0 else None),
-  ((UPat.var("x", dtypes.sints)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
+  ((UPat.var("x", dtypes.index)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
    lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None),
   # ** mod **
   # mod folding
   (UPat.var("x") % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None),
   (UPat.var("x") % UPat.var("d"), lambda x,d: (x%(-d)) if d.vmax < 0 else None),
   # up + x//c*c + x%c
   (UPat.var("up") + UPat.var("x", dtypes.ints)//UPat.cvar("c")*UPat.cvar("c") + UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda up,x,c: up+x),
-])+gep_pushing
+])+gep_pushing+cast_folding

 symbolic_flat = symbolic+PatternMatcher([
   # ** combine terms (opinionated) **
   (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)),  # -(x+y) -> -x + -y
   # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
-  ((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
+  ((UPat.var("x", dtypes.index) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
 ])

 # ******** we take a small aside to "simplify_valid" to rewrite valids ********

@@ -400,6 +404,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
   # simplify uop given that valid is True
   for expr,v in bounds.items():
     v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1])
+    expr = expr.substitute(load_subs)  # make sure expr appears in same form in the uop
     # some expr has lower bound > upper bound -> valid is an empty set and we return None
     if v0 > v1: return None
     # whole node became a const