enable test_resnet_half (#7141)

already worked so just fixed the test
This commit is contained in:
chenyu
2024-10-17 19:02:20 -04:00
committed by GitHub
parent 211d9753f8
commit 72ed66205d

View File

@@ -667,7 +667,6 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
"""
class TestSymbolicRealWorld(unittest.TestCase):
@unittest.expectedFailure
def test_resnet_half(self):
gidx0 = Variable("gidx0", 0, 3)
gidx1 = Variable("gidx1", 0, 127)
@@ -676,10 +675,12 @@ class TestSymbolicRealWorld(unittest.TestCase):
lidx4 = Variable("lidx4", 0, 1)
lidx5 = Variable("lidx5", 0, 15)
idx = ((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3)
print(idx.render())
idx:UOp = ((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3)
idx = graph_rewrite(idx, sym)
# print(idx.render())
# NOTE: this used to have 13,151,129,600 in the output which is out of int32 range.
assert idx.render() == "((((1+lidx5)%16)*49)+(((1+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+2207744+lidx3)"
assert idx.render() == \
"((((((((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49))+(gidx0*3211264))+(gidx1*784))+(gidx2*8))+(lidx4*100352))+lidx3)+2207744)"
class TestBounds(unittest.TestCase):
def test_unrolled_arange(self):