mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
add rendering to index (#12338)
This commit is contained in:
@@ -814,20 +814,6 @@ class TestShapeTrackerSize(unittest.TestCase):
|
||||
st = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).flip((True, True))
|
||||
self.assertEqual(st.real_size(), 100)
|
||||
|
||||
class TestRender(unittest.TestCase):
|
||||
def test_render(self):
|
||||
st = ShapeTracker.from_shape((2, 3))
|
||||
valid_idx = st.to_valid_uop()
|
||||
idx, valid = valid_idx.get_idx(), valid_idx.get_valid()
|
||||
self.assertEqual(idx.render(), "((ridx0*3)+ridx1)")
|
||||
self.assertEqual(valid.render(), "True")
|
||||
|
||||
st = st.pad(((0, 1), (0, 0)))
|
||||
valid_idx = st.to_valid_uop()
|
||||
idx, valid = valid_idx.get_idx(), valid_idx.get_valid()
|
||||
self.assertEqual(idx.render(), "((ridx0*3)+ridx1)")
|
||||
self.assertEqual(valid.render(), "(ridx0<2)")
|
||||
|
||||
class TestVariableShrink(unittest.TestCase):
|
||||
def test_shrink(self):
|
||||
st = ShapeTracker.from_shape((10,))
|
||||
|
||||
@@ -71,8 +71,8 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
idx = ridx0+ridx1+ridx2+ridx3
|
||||
load = get_gated_load_uop(valid, idx)
|
||||
self.check(load,
|
||||
"(((ridx0+ridx1)+ridx2)+ridx3)",
|
||||
"((((ridx0*3)+ridx1)<8)&((((ridx2*3)+ridx3)%4)<2))")
|
||||
"(((r0+r1)+r2)+r3)",
|
||||
"((((r0*3)+r1)<8)&((((r2*3)+r3)%4)<2))")
|
||||
|
||||
def test_simplify_within_valid2(self):
|
||||
gidx0 = Special("gidx0", 56)
|
||||
@@ -85,8 +85,8 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
ridx0 = Range(0, 2)
|
||||
v0 = ridx0<1
|
||||
v1 = ((ridx0*5+1)%6)<5
|
||||
self.assertEqual(simplify_valid(v0&v1).render(), "(ridx0<1)")
|
||||
self.assertEqual(simplify_valid(v1&v0).render(), "(ridx0<1)")
|
||||
self.assertEqual(simplify_valid(v0&v1).render(), "(r0<1)")
|
||||
self.assertEqual(simplify_valid(v1&v0).render(), "(r0<1)")
|
||||
|
||||
def test_valid_order_matters2(self):
|
||||
gidx0 = Special("gidx0", 13)
|
||||
@@ -128,8 +128,8 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
valid = ((((((ridx2*2)+(ridx3*3))+3)%4)<2)!=True) # noqa: E712
|
||||
load = get_gated_load_uop(valid, idx)
|
||||
self.check(load,
|
||||
"(((ridx0*2)+(ridx3*-1))+1)",
|
||||
"(ridx2<1)")
|
||||
"(((r0*2)+(r3*-1))+1)",
|
||||
"(r2<1)")
|
||||
|
||||
def test_load_in_valid(self):
|
||||
# from FUSE_ARANGE=1 python test/test_ops.py TestOps.test_scatter_add
|
||||
@@ -154,8 +154,8 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
valid = (ridx2<1)&(ridx1<6)
|
||||
load = get_gated_load_uop(valid, idx)
|
||||
self.check(load,
|
||||
"(ridx0*1568)",
|
||||
"((ridx2<1)&(ridx1<6))")
|
||||
"(r0*1568)",
|
||||
"((r2<1)&(r1<6))")
|
||||
|
||||
def test_valid_becomes_const1_z3(self):
|
||||
from z3 import Ints, Solver, And, If, Not, unsat
|
||||
@@ -195,7 +195,7 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
load = get_gated_load_uop(valid, idx)
|
||||
self.check(load,
|
||||
"1",
|
||||
"((((ridx0+ridx1)<1)!=True)&(((ridx2+ridx3)<1)!=True))")
|
||||
"((((r0+r1)<1)!=True)&(((r2+r3)<1)!=True))")
|
||||
|
||||
def test_valid_with_non_const_rhs(self):
|
||||
ridx0 = Range(0, 2**16)
|
||||
@@ -205,8 +205,8 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
idx = ridx0%1024
|
||||
load = get_gated_load_uop(valid, idx)
|
||||
self.check(load,
|
||||
"ridx0",
|
||||
"(ridx0<((ridx1*4)+ridx2))")
|
||||
"r0",
|
||||
"(r0<((r1*4)+r2))")
|
||||
|
||||
class TestImageSimplification(unittest.TestCase):
|
||||
def check(self, load, svalid, sidx0, sidx1):
|
||||
@@ -304,7 +304,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
idx = ((alu4+1530)%1536, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))
|
||||
|
||||
load = get_load_image_uop(shape, valid, idx)
|
||||
self.check(load, None, "((((idx1*48)+(ridx2*6))+ridx0)+-6)", "(((idx2*2)+ridx1)+-1)")
|
||||
self.check(load, None, "((((idx1*48)+(r2*6))+r0)+-6)", "(((idx2*2)+r1)+-1)")
|
||||
|
||||
def test_openpilot_conv2(self):
|
||||
# conv in test/external/external_test_valid_remove.py
|
||||
@@ -325,7 +325,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
idx = ((alu3+765)%768, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))
|
||||
load = get_load_image_uop(shape, valid, idx)
|
||||
|
||||
self.check(load, None, "((((idx1*24)+(ridx2*3))+ridx0)+-3)", "(((idx2*2)+ridx1)+-1)")
|
||||
self.check(load, None, "((((idx1*24)+(r2*3))+r0)+-3)", "(((idx2*2)+r1)+-1)")
|
||||
|
||||
def test_openpilot_conv3(self):
|
||||
# in openpilot 0.9.7
|
||||
@@ -346,9 +346,9 @@ class TestImageSimplification(unittest.TestCase):
|
||||
load = get_load_image_uop(shape, valid, idx)
|
||||
|
||||
self.check(load,
|
||||
"((((idx2*2)+ridx0)<11)&((((idx1*8)+ridx1)<3)!=True))",
|
||||
"(((idx0+((idx1*512)+(ridx1*64)))+832)%1024)",
|
||||
"((((idx2*2)+ridx0)+(((idx1+((ridx1+5)//8))+1)//2))+-4)")
|
||||
"((((idx2*2)+r0)<11)&((((idx1*8)+r1)<3)!=True))",
|
||||
"(((idx0+((idx1*512)+(r1*64)))+832)%1024)",
|
||||
"((((idx2*2)+r0)+(((idx1+((r1+5)//8))+1)//2))+-4)")
|
||||
|
||||
def test_simplify1(self):
|
||||
# idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
|
||||
@@ -424,7 +424,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
alu1 = ((idx2*1536)+(ridx4*768)+ridx3+(idx1*24)+(ridx5*3)+-771)//768
|
||||
valid = (((idx2+ridx4)<1)!=1)&(((idx1+ridx5)<1)!=1)
|
||||
load = get_load_image_uop((128, 768, 4), valid, (alu0, alu1))
|
||||
self.check(load, None, "((((idx1*24)+ridx3)+(ridx5*3))+-3)", "(((idx2*2)+ridx4)+-1)")
|
||||
self.check(load, None, "((((idx1*24)+r3)+(r5*3))+-3)", "(((idx2*2)+r4)+-1)")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -246,7 +246,7 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_range_mod_its_symbolic_bound(self):
|
||||
a = Variable("a", 1, 10, dtypes.index)
|
||||
ridx = UOp.range(a+2, 0)
|
||||
self.helper_test_variable(ridx%(a+2), 0, 11, "ridx0")
|
||||
self.helper_test_variable(ridx%(a+2), 0, 11, "r0")
|
||||
|
||||
def test_div_min_max(self):
|
||||
self.helper_test_variable(Variable("a", 2, 7) // 2, 1, 3, "(a//2)")
|
||||
|
||||
Reference in New Issue
Block a user