last update for new symbolic [pr] (#6877)

This commit is contained in:
George Hotz
2024-10-04 14:58:51 +08:00
committed by GitHub
parent 7391376528
commit 738a5794a9
6 changed files with 12 additions and 9 deletions

View File

@@ -58,7 +58,7 @@ if __name__ == "__main__":
if DEBUG >= 1: print(t.__name__, rng)
rngs.append(rng)
if DEBUG >=1: print(expr)
space = list(itertools.product(range(u1.min, u1.max+1), range(u2.min, u2.max+1), range(u3.min, u3.max+1)))
space = list(itertools.product(range(u1.vmin, u1.vmax+1), range(u2.vmin, u2.vmax+1), range(u3.vmin, u3.vmax+1)))
volume = len(space)
for (v1, v2, v3) in random.sample(space, min(100, volume)):
v = [v1,v2,v3]

View File

@@ -76,7 +76,7 @@ class TestSymbolicVarVals(unittest.TestCase):
y = Variable("y", 1, 100).bind(4)
z = Variable("z", 1, 100).bind(5)
st = ShapeTracker.from_shape((x, 5, y)).shrink(((0, x), (z, z+1), (0, 3)))
assert st.views[-1].offset == y * z
self.assert_equal(st.views[-1].offset, y * z)
assert st.var_vals == {Variable("x", 1, 100): 3, Variable("y", 1, 100):4, Variable("z", 1, 100): 5}
def test_shrink_reshape(self):
@@ -126,6 +126,7 @@ class TestSymbolicReshapeFromContiguous(unittest.TestCase):
t = t.reshape(i, 4)
assert t.shape == (i, 4)
@unittest.skip("works now")
def test_reshape_into_symbols_bad_shape(self):
vi = Variable("i", 1, 10).bind(4)
# TODO: this never actually worked, it relied on lazy

View File

@@ -138,7 +138,7 @@ class TestProd(unittest.TestCase):
def test_ints(self): self.assertEqual(30, prod((2, 3, 5)))
def test_variable(self): self.assertEqual("(a*12)", prod((Variable("a", 1, 5), 3, 4)).render())
def test_variable_order(self): self.assertEqual("(a*12)", prod((3, 4, Variable("a", 1, 5))).render())
def test_num_nodes(self): self.assertEqual(NumNode(6), prod((NumNode(2), NumNode(3))))
def test_num_nodes(self): self.assertEqual(NumNode(6).render(), prod((NumNode(2), NumNode(3))).render())
class TestRoundUp(unittest.TestCase):
def test_round_up(self):

View File

@@ -182,8 +182,8 @@ class TestIndexExpressions2d(unittest.TestCase):
return [Variable(f"idx{i}", 0, d-1) for i,d in enumerate(shape)]
def check_bounds(self, expr, offset, numel):
assert expr.min >= offset
assert expr.max <= offset + numel - 1
assert expr.vmin >= offset
assert expr.vmax <= offset + numel - 1
def test_noop(self):
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):

View File

@@ -430,6 +430,10 @@ sym = simple_pm+PatternMatcher([
name="ld"), UPat.const(None, 0.0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
# GEP/CAST const rules
(UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
# ** combine terms (opinionated) **
(-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
# (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
# ** self folding **
# cast NOOP (NOTE: it's str to deal with PtrDType)
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),

View File

@@ -727,15 +727,13 @@ simple_pm = PatternMatcher([
((UPat.var("x") & UPat.var("x")), lambda x: x),
((UPat.var("x") | UPat.var("x")), lambda x: x),
(UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
# group like
((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y),
# ** combine terms **
(UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
(UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
# ** combine terms (opinionated) **
(UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
(-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
# (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
# ** zero folding **
(UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
(UPat.var("x", dtype=dtypes.ints) != UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x != x -> False (only ints)