diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 1a883acb00..9f0ee23a7b 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -667,7 +667,6 @@ class TestSymbolicSymbolicOps(unittest.TestCase): """ class TestSymbolicRealWorld(unittest.TestCase): - @unittest.expectedFailure def test_resnet_half(self): gidx0 = Variable("gidx0", 0, 3) gidx1 = Variable("gidx1", 0, 127) @@ -676,10 +675,12 @@ class TestSymbolicRealWorld(unittest.TestCase): lidx4 = Variable("lidx4", 0, 1) lidx5 = Variable("lidx5", 0, 15) - idx = ((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3) - print(idx.render()) + idx:UOp = ((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3) + idx = graph_rewrite(idx, sym) + # print(idx.render()) # NOTE: this used to have 13,151,129,600 in the output which is out of int32 range. - assert idx.render() == "((((1+lidx5)%16)*49)+(((1+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+2207744+lidx3)" + assert idx.render() == \ + "((((((((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49))+(gidx0*3211264))+(gidx1*784))+(gidx2*8))+(lidx4*100352))+lidx3)+2207744)" class TestBounds(unittest.TestCase): def test_unrolled_arange(self):