mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
last update for new symbolic [pr] (#6877)
test/external/fuzz_symbolic.py (vendored): 2 changed lines
@@ -58,7 +58,7 @@ if __name__ == "__main__":
  if DEBUG >= 1: print(t.__name__, rng)
  rngs.append(rng)
  if DEBUG >=1: print(expr)
- space = list(itertools.product(range(u1.min, u1.max+1), range(u2.min, u2.max+1), range(u3.min, u3.max+1)))
+ space = list(itertools.product(range(u1.vmin, u1.vmax+1), range(u2.vmin, u2.vmax+1), range(u3.vmin, u3.vmax+1)))
  volume = len(space)
  for (v1, v2, v3) in random.sample(space, min(100, volume)):
  v = [v1,v2,v3]
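The only change in this file swaps the old Variable bounds (min/max) for the new UOp-style bounds (vmin/vmax). A minimal sketch of how the fuzzer samples concrete values from those bounds; the vmin/vmax usage comes straight from the diff, while the top-level import path is an assumption:

    # sample concrete assignments for symbolic variables from their vmin/vmax bounds
    import itertools, random
    from tinygrad import Variable   # assumed import path

    u1, u2 = Variable("u1", 1, 3), Variable("u2", 0, 2)
    space = list(itertools.product(range(u1.vmin, u1.vmax+1), range(u2.vmin, u2.vmax+1)))
    for v1, v2 in random.sample(space, min(4, len(space))):
      print(v1, v2)  # concrete values drawn from the declared variable ranges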
@@ -76,7 +76,7 @@ class TestSymbolicVarVals(unittest.TestCase):
  y = Variable("y", 1, 100).bind(4)
  z = Variable("z", 1, 100).bind(5)
  st = ShapeTracker.from_shape((x, 5, y)).shrink(((0, x), (z, z+1), (0, 3)))
- assert st.views[-1].offset == y * z
+ self.assert_equal(st.views[-1].offset, y * z)
  assert st.var_vals == {Variable("x", 1, 100): 3, Variable("y", 1, 100):4, Variable("z", 1, 100): 5}

  def test_shrink_reshape(self):
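The bare assert on offset == y * z becomes self.assert_equal(...), presumably because the view offset is now a symbolic expression for which == does not reduce to a plain Python bool. A hypothetical stand-alone helper in the same spirit (this assert_equal is my own sketch, not the test class's actual implementation):

    def assert_equal(a, b):
      # compare symbolic expressions by their rendered string when they expose
      # render() (as the TestProd hunk below does), otherwise compare directly
      ra = a.render() if hasattr(a, "render") else a
      rb = b.render() if hasattr(b, "render") else b
      assert ra == rb, f"{ra} != {rb}"

    assert_equal(6, 6)  # plain ints still work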
@@ -126,6 +126,7 @@ class TestSymbolicReshapeFromContiguous(unittest.TestCase):
  t = t.reshape(i, 4)
  assert t.shape == (i, 4)

+ @unittest.skip("works now")
  def test_reshape_into_symbols_bad_shape(self):
  vi = Variable("i", 1, 10).bind(4)
  # TODO: this never actually worked, it relied on lazy
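The skipped test aside, the surrounding lines show the intended usage: bind a Variable to a concrete value and reshape into it. A small usage sketch along the same lines, assuming tinygrad is installed and that reshaping a contiguous tensor into a bound symbolic dim is supported as the test class name suggests (behavior may differ across versions):

    from tinygrad import Tensor, Variable

    vi = Variable("i", 1, 10).bind(4)      # symbolic dim bound to 4
    t = Tensor.rand(4, 4).reshape(vi, 4)   # reshape into the symbolic shape
    assert t.shape == (vi, 4)              # shape keeps the bound variable, as in the test above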
@@ -138,7 +138,7 @@ class TestProd(unittest.TestCase):
  def test_ints(self): self.assertEqual(30, prod((2, 3, 5)))
  def test_variable(self): self.assertEqual("(a*12)", prod((Variable("a", 1, 5), 3, 4)).render())
  def test_variable_order(self): self.assertEqual("(a*12)", prod((3, 4, Variable("a", 1, 5))).render())
- def test_num_nodes(self): self.assertEqual(NumNode(6), prod((NumNode(2), NumNode(3))))
+ def test_num_nodes(self): self.assertEqual(NumNode(6).render(), prod((NumNode(2), NumNode(3))).render())

  class TestRoundUp(unittest.TestCase):
  def test_round_up(self):
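The NumNode case now compares rendered strings rather than node objects, matching how the Variable-based prod tests above already check their results. A quick check mirroring those cases; the expected values come from the tests themselves, while the prod import location is an assumption:

    from tinygrad.helpers import prod   # assumed location of prod
    from tinygrad import Variable       # assumed import path

    a = Variable("a", 1, 5)
    print(prod((2, 3, 5)))               # 30: plain ints stay plain
    print(prod((a, 3, 4)).render())      # "(a*12)" per test_variable; compare via render(), not ==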
@@ -182,8 +182,8 @@ class TestIndexExpressions2d(unittest.TestCase):
  return [Variable(f"idx{i}", 0, d-1) for i,d in enumerate(shape)]

  def check_bounds(self, expr, offset, numel):
- assert expr.min >= offset
- assert expr.max <= offset + numel - 1
+ assert expr.vmin >= offset
+ assert expr.vmax <= offset + numel - 1

  def test_noop(self):
  for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
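check_bounds now reads the conservative bounds of an index expression from vmin/vmax, the same attributes the fuzzer hunk above switched to. A short sketch of what those bounds look like for a 2-D index expression (top-level import path assumed):

    from tinygrad import Variable   # assumed import path

    idx0, idx1 = Variable("idx0", 0, 3), Variable("idx1", 0, 4)
    expr = idx0*5 + idx1            # row-major index into a 4x5 view with offset 0
    print(expr.vmin, expr.vmax)     # expected 0 and 19, i.e. within offset..offset+numel-1 for numel=20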
@@ -430,6 +430,10 @@ sym = simple_pm+PatternMatcher([
  name="ld"), UPat.const(None, 0.0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
  # GEP/CAST const rules
  (UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
+ # ** combine terms (opinionated) **
+ (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
+ # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
+ ((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
  # ** self folding **
  # cast NOOP (NOTE: it's str to deal with PtrDType)
  (UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
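The four added lines move the negation and distribution rules into the opinionated sym matcher. The int-only guard on (x+y)*c -> x*c+y*c exists because float arithmetic is not distributive once inf or nan appears; the rule's own comment cites inf*0 == nan. One concrete plain-Python illustration of the non-distributivity (not necessarily the exact case the comment refers to):

    x, y, c = 1e308, -1e308, 2.0
    print((x + y) * c)   # 0.0
    print(x*c + y*c)     # nan: 1e308*2.0 overflows to inf, and inf + -inf is nan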
@@ -727,15 +727,13 @@ simple_pm = PatternMatcher([
  ((UPat.var("x") & UPat.var("x")), lambda x: x),
  ((UPat.var("x") | UPat.var("x")), lambda x: x),
  (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
  # group like
  ((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y),
  # ** combine terms **
  (UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
  (UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
  # ** combine terms (opinionated) **
  (UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
  ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
- (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
- # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
- ((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
  # ** zero folding **
  (UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
  (UPat.var("x", dtype=dtypes.ints) != UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x != x -> False (only ints)
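The lines removed here match the ones added to the sym matcher above, so simple_pm keeps only the uncontroversial folds. Note the asymmetry in the zero-folding pair that remains: x < x can fold to False for any dtype (even nan < nan is False), while x != x is folded only for int dtypes because nan != nan is True. A two-line check of that float caveat:

    nan = float("nan")
    print(nan < nan, nan != nan)   # False True: folding x != x to False is only valid for ints
    print(7 < 7, 7 != 7)           # False False: both folds are safe for ints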