mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-16 01:26:29 -05:00
brute force VALIDHACK matching (#6575)
* brute force VALIDHACK matching * cleanup * 9700
This commit is contained in:
@@ -15,6 +15,7 @@ def render(image_shape, valid:UOp, idx:UOp) -> str:
|
||||
class TestRenderer(OpenCLRenderer):
|
||||
code_for_op = {**OpenCLRenderer().code_for_op, BinaryOps.IDIV: lambda a,b,dtype: f"({a}//{b})"}
|
||||
fxn = TestRenderer().render("", uops)
|
||||
# print(fxn)
|
||||
return fxn.split("float4 val0 = ")[1].split(";")[0]
|
||||
|
||||
def Variable(expr, nmax):
|
||||
@@ -60,5 +61,30 @@ class TestValidSimplification(unittest.TestCase):
|
||||
self.assertEqual(render((10, 10, 4), (gidx1).lt(5), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1+5))),
|
||||
"read_imagef(data0, smp, (int2)(gidx0,(gidx1+5)))")
|
||||
|
||||
def test_simplify1(self):
|
||||
# idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
|
||||
gidx = Variable("gidx", 512)
|
||||
valid = gidx.lt(488) & (-gidx).lt(-479)
|
||||
idx = ((gidx*3+18)%26, (gidx*3+18)//26-56)
|
||||
# alu0 is ((gidx*3)+18)
|
||||
self.assertEqual(render((1, 26, 4), valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx)),
|
||||
"read_imagef(data0, smp, (int2)(((gidx*3)+(-1438)),0))")
|
||||
|
||||
def test_simplify2(self):
|
||||
# from GPU=1 DEBUG=4 FORWARD_ONLY=1 IMAGE=2 python3 test/test_ops.py TestOps.test_simple_padding_conv2d
|
||||
lidx = Variable("lidx", 4)
|
||||
valid = lidx.lt(3) & (-lidx).lt(0)
|
||||
idx = ((lidx+1)%2, (lidx+1)//2-1)
|
||||
self.assertEqual(render((1, 2, 4), valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx)),
|
||||
"read_imagef(data0, smp, (int2)((lidx+(-1)),0))")
|
||||
|
||||
def test_simplify3(self):
|
||||
# from openpilot
|
||||
idx0 = Variable("idx0", 265)
|
||||
valid = (-idx0).lt(-200)
|
||||
idx = ((idx0+55)%64, (idx0+55)//64-4)
|
||||
self.assertEqual(render((1, 64, 4), valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx)),
|
||||
"read_imagef(data0, smp, (int2)((idx0+(-201)),0))")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user