mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
with Tensor.train() (#1935)
* add with.train * remove the rest TODOs * fix pyflake * fix pyflake error * fix mypy
This commit is contained in:
@@ -154,12 +154,11 @@ class TestSchedule(unittest.TestCase):
|
||||
|
||||
#@unittest.skip("may want to reconsider this")
|
||||
def test_fold_batchnorm(self):
|
||||
Tensor.training = True
|
||||
img = Tensor.empty(1,32,4,4)
|
||||
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
||||
out = bn(img)
|
||||
check_schedule(out, 3)
|
||||
Tensor.training = False
|
||||
with Tensor.train():
|
||||
img = Tensor.empty(1,32,4,4)
|
||||
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
||||
out = bn(img)
|
||||
check_schedule(out, 3)
|
||||
|
||||
def test_fold_conv_relu(self):
|
||||
c1 = nn.Conv2d(3,16,3)
|
||||
|
||||
Reference in New Issue
Block a user