From bd40a26b8b87bb17adb0e5d8d6cbee21bd9c2e20 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 18 Sep 2024 06:06:03 -0400 Subject: [PATCH] image valid test case that current approach does not work (#6584) --- test/unit/test_image_valid.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/unit/test_image_valid.py b/test/unit/test_image_valid.py index 6ed9cca560..e09b782654 100644 --- a/test/unit/test_image_valid.py +++ b/test/unit/test_image_valid.py @@ -98,5 +98,29 @@ class TestValidSimplification(unittest.TestCase): self.assertEqual(render((1, 64, 4), valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx)), "read_imagef(data0, smp, (int2)((idx0+(-201)),0))") + def test_simplify4(self): + idx0 = Variable("idx0", 512) + data1_shape = (4, 64, 4) + alu2 = ((idx0*4+1)%32) + alu3 = ((idx0*4+2)%32) + alu4 = ((idx0*4+3)%32) + alu5 = (idx0*4%32) + alu8 = (idx0//8%32//4) + alu9 = idx0.lt(256) + + # TODO: simplify these, manual parsing is not going to work + # alu0 = (((idx0*4)+1)%32) + self.assertEqual(render(data1_shape, alu9, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu8+(alu2*8))%64),(alu2//8)))), + "((idx0<256)?read_imagef(data0, smp, (int2)((((((idx0//8)%32)//4)+(alu0*8))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))") + # alu0 = (((idx0*4)+2)%32) + self.assertEqual(render(data1_shape, alu9, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu8+(alu3*8))%64),(alu3//8)))), + "((idx0<256)?read_imagef(data0, smp, (int2)((((((idx0//8)%32)//4)+(alu0*8))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))") + # alu0 = (((idx0*4)+3)%32) + self.assertEqual(render(data1_shape, alu9, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu8+(alu4*8))%64),(alu4//8)))), + "((idx0<256)?read_imagef(data0, smp, (int2)((((((idx0//8)%32)//4)+(alu0*8))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))") + # alu0 = ((idx0*4)%32) + self.assertEqual(render(data1_shape, alu9, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu8+(alu5*8))%64),(alu5//8)))), + "((idx0<256)?read_imagef(data0, smp, (int2)((((((idx0//8)%32)//4)+(alu0*8))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))") + if __name__ == '__main__': unittest.main()