mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-29 03:00:14 -04:00)
with Tensor.train() (#1935)
* add with.train
* remove the rest TODOs
* fix pyflake
* fix pyflake error
* fix mypy
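The commit swaps every manual Tensor.training = True ... Tensor.training = False pair for a with Tensor.train(): block. A minimal sketch of how such a context manager can be written follows; this is the standard save/set/restore pattern, and the exact implementation inside tinygrad may differ (for example, it could be a class with __enter__/__exit__):

    from contextlib import contextmanager
    from tinygrad.tensor import Tensor

    @contextmanager
    def train(val=True):
      # save the old flag, set the requested one, and restore the old
      # value on exit, even if the body raises
      prev, Tensor.training = Tensor.training, val
      try:
        yield
      finally:
        Tensor.training = prev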
test/external/external_test_opt.py (vendored): 130 changed lines
@@ -171,73 +171,68 @@ class TestOpt(unittest.TestCase):
     assert ret.numpy()[0] == 33

   def test_fold_batchnorm(self):
-    # TODO: with Tensor.training
-    Tensor.training = True
-    img = Tensor.ones(1,32,4,4)
-    bn = nn.BatchNorm2d(32, track_running_stats=False)
-    with CLCache():
-      img_bn = bn(img).realize()
-      print(img_bn)
-      assert len(CacheCollector.cache) == 3, f"optimizer didn't fold batchnorm, got {len(CacheCollector.cache)}"
-    Tensor.training = False
+    with Tensor.train():
+      img = Tensor.ones(1,32,4,4)
+      bn = nn.BatchNorm2d(32, track_running_stats=False)
+      with CLCache():
+        img_bn = bn(img).realize()
+        print(img_bn)
+        assert len(CacheCollector.cache) == 3, f"optimizer didn't fold batchnorm, got {len(CacheCollector.cache)}"
+      # Tensor.training = False

   def test_fold_conv_sgd(self):
-    # TODO: with Tensor.training
-    Tensor.training = True
-    img = Tensor.ones(2,3,4,4)
-    c1 = nn.Conv2d(3,32,3)
-    opt = optim.SGD(get_parameters(c1))
-    with CLCache():
-      opt.zero_grad()
-      c1(img).relu().sum().backward()
-      opt.step()
-      # TODO: this should be 4, but the sum output child stays around
-      # with pushing_permutes it can be 3
-      # TODO: broken with optim fixes
-      assert len(CacheCollector.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(CacheCollector.cache)}"
-    Tensor.training = False
+    with Tensor.train():
+      img = Tensor.ones(2,3,4,4)
+      c1 = nn.Conv2d(3,32,3)
+      opt = optim.SGD(get_parameters(c1))
+      with CLCache():
+        opt.zero_grad()
+        c1(img).relu().sum().backward()
+        opt.step()
+        # TODO: this should be 4, but the sum output child stays around
+        # with pushing_permutes it can be 3
+        # TODO: broken with optim fixes
+        assert len(CacheCollector.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(CacheCollector.cache)}"
+      # Tensor.training = False

   def test_fold_2convs_sgd(self):
-    # TODO: with Tensor.training
-    Tensor.training = True
-    img = Tensor.ones(2,3,64,64)
-    c1 = nn.Conv2d(3,16,3,bias=False)
-    c2 = nn.Conv2d(16,32,3,bias=False)
-    opt = optim.SGD(get_parameters([c1, c2]))
-    with CLCache(allowed=9):
-      opt.zero_grad()
-      c2(c1(img).relu()).relu().sum().backward()
-      opt.step()
-    Tensor.training = False
+    with Tensor.train():
+      img = Tensor.ones(2,3,64,64)
+      c1 = nn.Conv2d(3,16,3,bias=False)
+      c2 = nn.Conv2d(16,32,3,bias=False)
+      opt = optim.SGD(get_parameters([c1, c2]))
+      with CLCache(allowed=9):
+        opt.zero_grad()
+        c2(c1(img).relu()).relu().sum().backward()
+        opt.step()
+      # Tensor.training = False

   def test_fold_4convs_sgd(self):
-    # TODO: with Tensor.training
-    Tensor.training = True
-    img = Tensor.ones(2,3,64,64)
-    c1 = nn.Conv2d(3,4,3,bias=False)
-    c2 = nn.Conv2d(4,8,3,bias=False)
-    c3 = nn.Conv2d(8,16,3,bias=False)
-    c4 = nn.Conv2d(16,32,3,bias=False)
-    opt = optim.SGD(get_parameters([c1, c2, c3, c4]))
-    with CLCache(allowed=19):
-      opt.zero_grad()
-      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
-      opt.step()
-    Tensor.training = False
+    with Tensor.train():
+      img = Tensor.ones(2,3,64,64)
+      c1 = nn.Conv2d(3,4,3,bias=False)
+      c2 = nn.Conv2d(4,8,3,bias=False)
+      c3 = nn.Conv2d(8,16,3,bias=False)
+      c4 = nn.Conv2d(16,32,3,bias=False)
+      opt = optim.SGD(get_parameters([c1, c2, c3, c4]))
+      with CLCache(allowed=19):
+        opt.zero_grad()
+        c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
+        opt.step()
+      # Tensor.training = False

   def test_fold_conv_batchnorm_sgd(self):
-    # TODO: with Tensor.training
-    Tensor.training = True
-    img = Tensor.ones(1,3,4,4)
-    c1 = nn.Conv2d(3,32,3)
-    bn = nn.BatchNorm2d(32, track_running_stats=False)
-    opt = optim.SGD(get_parameters([c1, bn]))
-    with CLCache(allowed=18): # this is too high
-      img_bn = bn(c1(img)).elu().sum()
-      opt.zero_grad()
-      img_bn.backward()
-      opt.step()
-    Tensor.training = False
+    with Tensor.train():
+      img = Tensor.ones(1,3,4,4)
+      c1 = nn.Conv2d(3,32,3)
+      bn = nn.BatchNorm2d(32, track_running_stats=False)
+      opt = optim.SGD(get_parameters([c1, bn]))
+      with CLCache(allowed=18): # this is too high
+        img_bn = bn(c1(img)).elu().sum()
+        opt.zero_grad()
+        img_bn.backward()
+        opt.step()
+      # Tensor.training = False

   def test_fold_conv_batchnorm_notrain(self):
     img = Tensor.ones(1,3,8,8)
@@ -250,15 +245,14 @@ class TestOpt(unittest.TestCase):
     assert len(CacheCollector.cache) == 1, f"optimizer didn't fold conv-batchnorm at test time, got {len(CacheCollector.cache)}"

   def test_fold_conv_batchnorm(self):
-    Tensor.training = True
-    img = Tensor.ones(1,3,8,8)
-    c1 = nn.Conv2d(3,32,3)
-    bn = nn.BatchNorm2d(32, track_running_stats=False)
-    with CLCache():
-      img_conv = bn(c1(img)).relu().realize()
-      print(img_conv)
-      assert len(CacheCollector.cache) == 4, f"optimizer didn't fold conv-batchnorm, got {len(CacheCollector.cache)}"
-    Tensor.training = False
+    with Tensor.train():
+      img = Tensor.ones(1,3,8,8)
+      c1 = nn.Conv2d(3,32,3)
+      bn = nn.BatchNorm2d(32, track_running_stats=False)
+      with CLCache():
+        img_conv = bn(c1(img)).relu().realize()
+        print(img_conv)
+        assert len(CacheCollector.cache) == 4, f"optimizer didn't fold conv-batchnorm, got {len(CacheCollector.cache)}"

   def test_fold_conv_elu(self):
     img = Tensor.ones(1,4,8,8)
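Note on the assertions above: CLCache appears to record every GPU kernel scheduled inside its block via CacheCollector, and CLCache(allowed=N) fails the test when more than N kernels are launched. As an illustration only, a toy guard with the same shape could look like this (KernelCountGuard and its add() hook are hypothetical names for this sketch, not tinygrad's actual helper):

    class KernelCountGuard:
      # hypothetical stand-in for the CLCache(allowed=N) pattern: the
      # runtime is assumed to call add() once per kernel it launches
      # inside the block, and the guard fails on exit over budget
      def __init__(self, allowed=None):
        self.allowed, self.cache = allowed, []
      def add(self, kernel):
        self.cache.append(kernel)
      def __enter__(self):
        self.cache.clear()  # start each block with an empty kernel log
        return self
      def __exit__(self, exc_type, exc, tb):
        if exc_type is None and self.allowed is not None:
          assert len(self.cache) <= self.allowed, f"used too many kernels: got {len(self.cache)}, allowed {self.allowed}"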
test/external/graph_batchnorm.py (vendored): 17 changed lines
@@ -4,15 +4,14 @@ from tinygrad.tensor import Tensor
 from tinygrad.nn import Conv2d, BatchNorm2d, optim

 def model_step(lm):
-  Tensor.training = True
-  x = Tensor.ones(8,12,128,256, requires_grad=False)
-  optimizer = optim.SGD(get_parameters(lm), lr=0.001)
-  loss = lm.forward(x).sum()
-  optimizer.zero_grad()
-  loss.backward()
-  del x,loss
-  optimizer.step()
-  Tensor.training = False
+  with Tensor.train():
+    x = Tensor.ones(8,12,128,256, requires_grad=False)
+    optimizer = optim.SGD(get_parameters(lm), lr=0.001)
+    loss = lm.forward(x).sum()
+    optimizer.zero_grad()
+    loss.backward()
+    del x,loss
+    optimizer.step()

 class TestBatchnorm(unittest.TestCase):
   def test_conv(self):
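Why the context manager matters in model_step: in the old version, Tensor.training = False at the end never executes if forward() or backward() raises, leaving the global flag stuck at True for whatever runs next; the with Tensor.train(): form restores it no matter how the block exits. A self-contained demonstration of that property, using plain Python stand-ins (Cfg here is a made-up placeholder for Tensor, no tinygrad required):

    from contextlib import contextmanager

    class Cfg: training = False

    @contextmanager
    def train(val=True):
      prev, Cfg.training = Cfg.training, val
      try:
        yield
      finally:
        Cfg.training = prev  # restored even when the body raises

    try:
      with train():
        assert Cfg.training is True
        raise RuntimeError("simulated failure in backward()")
    except RuntimeError:
      pass
    assert Cfg.training is False  # flag restored despite the exception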