with Tensor.train() (#1935)

* add with.train

* remove the rest TODOs

* fix pyflake

* fix pyflake error

* fix mypy
This commit is contained in:
Yixiang Gao
2023-09-28 20:02:31 -05:00
committed by GitHub
parent 10f0dc0c85
commit 094d3d71be
14 changed files with 305 additions and 317 deletions

View File

@@ -154,12 +154,11 @@ class TestSchedule(unittest.TestCase):
#@unittest.skip("may want to reconsider this")
def test_fold_batchnorm(self):
Tensor.training = True
img = Tensor.empty(1,32,4,4)
bn = nn.BatchNorm2d(32, track_running_stats=False)
out = bn(img)
check_schedule(out, 3)
Tensor.training = False
with Tensor.train():
img = Tensor.empty(1,32,4,4)
bn = nn.BatchNorm2d(32, track_running_stats=False)
out = bn(img)
check_schedule(out, 3)
def test_fold_conv_relu(self):
c1 = nn.Conv2d(3,16,3)