less iterations for symbolic double for loops (#12006)

This commit is contained in:
chenyu
2025-09-04 15:09:17 -04:00
committed by GitHub
parent 70ce29b630
commit 8c720e8760
2 changed files with 24 additions and 24 deletions

View File

@@ -97,8 +97,8 @@ class TestSymbolicJit(unittest.TestCase):
jf = TinyJit(f)
a = Tensor.rand(10, 3)
b = Tensor.rand(10, 3)
for i in range(1, 5):
for j in range(1, 5):
for i in range(2, 5):
for j in range(2, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
symbolic = jf(a[:vi], b[:vj]).reshape(i+j, 3).numpy()
@@ -111,8 +111,8 @@ class TestSymbolicJit(unittest.TestCase):
jf = TinyJit(f)
a = Tensor.rand(3, 10)
b = Tensor.rand(3, 10)
for i in range(1, 5):
for j in range(1, 5):
for i in range(2, 5):
for j in range(2, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
symbolic = jf(a[:, :vi], b[:, :vj]).reshape(3, i+j).numpy()
@@ -125,8 +125,8 @@ class TestSymbolicJit(unittest.TestCase):
jf = TinyJit(f)
a = Tensor.rand(10, 3)
b = Tensor.rand(3, 10)
for i in range(1, 5):
for j in range(1, 5):
for i in range(2, 5):
for j in range(2, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
symbolic = jf(a[:vi, :], b[:, :vj]).reshape(i, j).numpy()
@@ -139,8 +139,8 @@ class TestSymbolicJit(unittest.TestCase):
jf = TinyJit(f)
a = Tensor.rand(10, 3)
b = Tensor.rand(3, 10)
for i in range(1, 5):
for j in range(1, 5):
for i in range(2, 5):
for j in range(2, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
symbolic = jf(a[:vj, :], b[:, :vi]).reshape(j, i).numpy()
@@ -245,8 +245,8 @@ class TestSymbolicJit(unittest.TestCase):
a = Tensor.rand(10, 10)
b = Tensor.rand(10, 10)
c = Tensor.rand(10, 10)
for i in range(1, 5):
for j in range(1, 5):
for i in range(2, 5):
for j in range(2, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
# axis = None
@@ -297,8 +297,8 @@ class TestSymbolicJit(unittest.TestCase):
a = Tensor.rand(10, 10)
b = Tensor.rand(10, 10)
c = Tensor.rand(10, 10)
for i in range(1, 5):
for j in range(1, 5):
for i in range(2, 5):
for j in range(2, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
# axis = None