last update for new symbolic [pr] (#6877)

This commit is contained in:
George Hotz
2024-10-04 14:58:51 +08:00
committed by GitHub
parent 7391376528
commit 738a5794a9
6 changed files with 12 additions and 9 deletions

View File

@@ -58,7 +58,7 @@ if __name__ == "__main__":
if DEBUG >= 1: print(t.__name__, rng)
rngs.append(rng)
if DEBUG >=1: print(expr)
space = list(itertools.product(range(u1.min, u1.max+1), range(u2.min, u2.max+1), range(u3.min, u3.max+1)))
space = list(itertools.product(range(u1.vmin, u1.vmax+1), range(u2.vmin, u2.vmax+1), range(u3.vmin, u3.vmax+1)))
volume = len(space)
for (v1, v2, v3) in random.sample(space, min(100, volume)):
v = [v1,v2,v3]

View File

@@ -76,7 +76,7 @@ class TestSymbolicVarVals(unittest.TestCase):
y = Variable("y", 1, 100).bind(4)
z = Variable("z", 1, 100).bind(5)
st = ShapeTracker.from_shape((x, 5, y)).shrink(((0, x), (z, z+1), (0, 3)))
assert st.views[-1].offset == y * z
self.assert_equal(st.views[-1].offset, y * z)
assert st.var_vals == {Variable("x", 1, 100): 3, Variable("y", 1, 100):4, Variable("z", 1, 100): 5}
def test_shrink_reshape(self):
@@ -126,6 +126,7 @@ class TestSymbolicReshapeFromContiguous(unittest.TestCase):
t = t.reshape(i, 4)
assert t.shape == (i, 4)
@unittest.skip("works now")
def test_reshape_into_symbols_bad_shape(self):
vi = Variable("i", 1, 10).bind(4)
# TODO: this never actually worked, it relied on lazy

View File

@@ -138,7 +138,7 @@ class TestProd(unittest.TestCase):
def test_ints(self): self.assertEqual(30, prod((2, 3, 5)))
def test_variable(self): self.assertEqual("(a*12)", prod((Variable("a", 1, 5), 3, 4)).render())
def test_variable_order(self): self.assertEqual("(a*12)", prod((3, 4, Variable("a", 1, 5))).render())
def test_num_nodes(self): self.assertEqual(NumNode(6), prod((NumNode(2), NumNode(3))))
def test_num_nodes(self): self.assertEqual(NumNode(6).render(), prod((NumNode(2), NumNode(3))).render())
class TestRoundUp(unittest.TestCase):
def test_round_up(self):

View File

@@ -182,8 +182,8 @@ class TestIndexExpressions2d(unittest.TestCase):
return [Variable(f"idx{i}", 0, d-1) for i,d in enumerate(shape)]
def check_bounds(self, expr, offset, numel):
assert expr.min >= offset
assert expr.max <= offset + numel - 1
assert expr.vmin >= offset
assert expr.vmax <= offset + numel - 1
def test_noop(self):
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):