mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
@@ -667,7 +667,6 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
|
||||
"""
|
||||
|
||||
class TestSymbolicRealWorld(unittest.TestCase):
|
||||
@unittest.expectedFailure
|
||||
def test_resnet_half(self):
|
||||
gidx0 = Variable("gidx0", 0, 3)
|
||||
gidx1 = Variable("gidx1", 0, 127)
|
||||
@@ -676,10 +675,12 @@ class TestSymbolicRealWorld(unittest.TestCase):
|
||||
lidx4 = Variable("lidx4", 0, 1)
|
||||
lidx5 = Variable("lidx5", 0, 15)
|
||||
|
||||
idx = ((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3)
|
||||
print(idx.render())
|
||||
idx:UOp = ((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3)
|
||||
idx = graph_rewrite(idx, sym)
|
||||
# print(idx.render())
|
||||
# NOTE: this used to have 13,151,129,600 in the output which is out of int32 range.
|
||||
assert idx.render() == "((((1+lidx5)%16)*49)+(((1+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+2207744+lidx3)"
|
||||
assert idx.render() == \
|
||||
"((((((((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49))+(gidx0*3211264))+(gidx1*784))+(gidx2*8))+(lidx4*100352))+lidx3)+2207744)"
|
||||
|
||||
class TestBounds(unittest.TestCase):
|
||||
def test_unrolled_arange(self):
|
||||
|
||||
Reference in New Issue
Block a user