fix bn folding issue, add new test

George Hotz
2022-09-28 22:52:18 -04:00
parent a0d169eb59
commit 726cca78cd
3 changed files with 75 additions and 4 deletions

test/test_opt.py (new file)

@@ -0,0 +1,66 @@
#!/usr/bin/env python
import os
os.environ["OPT"] = "2"
import gc
import numpy as np
import unittest
from tinygrad.tensor import Tensor, Device
from tinygrad import nn
from tinygrad.llops.ops_gpu import CL
class CLCache():
  # record GPU kernel launches instead of running them, then replay them on exit
  def __enter__(self):
    # realize any tensors that are already alive so their kernels don't pollute the count
    gc.collect()
    for x in [x for x in gc.get_objects() if isinstance(x, Tensor)]:
      x.realize()
    CL.CACHE = []
    print("cache: entering")
  def __exit__(self, type, value, traceback):
    print(f"cache: exiting with size {len(CL.CACHE)}")
    # replay the recorded kernels, then stop recording
    for prg, args in CL.CACHE:
      e = prg.clprg(CL().cl_queue, *args)
    CL.CACHE = None

# build training-mode graphs (e.g. batchnorm batch stats), but without gradient tracking
Tensor.training = True
Tensor.no_grad = True

@unittest.skipUnless(Device.DEFAULT == Device.GPU, "Not Implemented")
class TestOpt(unittest.TestCase):
  def test_muladd(self):
    a,b,c = [Tensor.ones(2,2) for _ in range(3)]
    with CLCache():
      d = a * b + c
      d.realize()
      # assert while the cache is still live; __exit__ sets CL.CACHE back to None
      assert len(CL.CACHE) == 1, "optimizer didn't fold muladd"
    np.testing.assert_allclose(d.numpy(), np.ones((2,2))*2, rtol=1e-5)

  def test_fold_reduce_elementwise(self):
    img = Tensor.ones(32)
    addme = Tensor.ones(1)
    with CLCache():
      ret = img.sum() + addme
      ret.realize()
      assert len(CL.CACHE) == 1, "optimizer didn't fold reduce/elementwise"
    assert ret.numpy()[0] == 33

  def test_fold_batchnorm(self):
    img = Tensor.ones(1,32,4,4)
    bn = nn.BatchNorm2D(32, track_running_stats=False)
    with CLCache():
      img_bn = bn(img).realize()
      print(img_bn)
      assert len(CL.CACHE) == 3, "optimizer didn't fold batchnorm"

  def test_fold_conv_elu(self):
    img = Tensor.ones(1,4,8,8)
    c1 = nn.Conv2d(4, 4, kernel_size=3)
    c2 = nn.Conv2d(4, 4, kernel_size=3)
    with CLCache():
      img_conv = img.sequential([c1, Tensor.elu, c2, Tensor.elu]).realize()
      print(img_conv)
      assert len(CL.CACHE) == 2, "optimizer didn't fold conv/elu"

if __name__ == '__main__':
  unittest.main()
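
A hedged usage note on the CLCache helper above (assuming, as the tests do, that the GPU backend is the default device): wrapping a lazy expression in it appears to record the kernels that realize() would launch rather than running them (they are replayed in __exit__), so the kernel count can be checked while the cache is live. A minimal sketch:

from tinygrad.tensor import Tensor
a = Tensor.ones(16)
with CLCache():
  b = (a + 1).sum().realize()          # kernels are recorded into CL.CACHE, not yet executed
  print("kernels:", len(CL.CACHE))     # inspect the count before __exit__ clears the cache
print(b.numpy())                       # after exit the recorded kernels have been run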


@@ -23,6 +23,8 @@ class BatchNorm2D:
       batch_mean = x_detached.mean(axis=(0,2,3))
       y = (x_detached - batch_mean.reshape(shape=[1, -1, 1, 1]))
       batch_var = (y*y).mean(axis=(0,2,3))
+      batch_invstd = batch_var.add(self.eps)**-0.5
+      self.batch_invstd = None

       # NOTE: wow, this is done all throughout training in most PyTorch models
       if self.track_running_stats:
@@ -31,11 +33,12 @@ class BatchNorm2D:
         self.num_batches_tracked += 1
     else:
       batch_mean, batch_var = self.running_mean, self.running_var
+      # NOTE: this can be precomputed for static inference. if you manually update running_var, you have to reset this
+      if getattr(self, "batch_invstd", None) is None:
+        self.batch_invstd = batch_var.add(self.eps)**-0.5
+      batch_invstd = self.batch_invstd

-    # NOTE: this can be precomputed for static inference. if you manually update running_var, you have to reset this
-    if Tensor.training or getattr(self, "batch_invstd", None) is None:
-      self.batch_invstd = batch_var.add(self.eps)**-0.5
-    return batch_normalize(x, self.weight, self.bias, batch_mean, self.batch_invstd)
+    return batch_normalize(x, self.weight, self.bias, batch_mean, batch_invstd)

 # TODO: is this good weight init?
 class Conv2d:
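
To make the new caching rule in the diff above concrete, here is a minimal stand-alone sketch (a hypothetical class, not tinygrad's actual BatchNorm2D): while training, the inverse standard deviation is recomputed from the current batch's variance and any cached copy is discarded; in eval it is computed once from the running variance and reused until manually reset.

import numpy as np

class InvStdCache:
  # hypothetical stand-in for BatchNorm2D's batch_invstd handling, for illustration only
  def __init__(self, eps=1e-5):
    self.eps, self.batch_invstd = eps, None
  def __call__(self, var, training):
    if training:
      self.batch_invstd = None                      # training: never reuse a cached value
      return (var + self.eps) ** -0.5               # recompute from this batch's variance
    if self.batch_invstd is None:                   # eval: compute once from running_var and cache
      self.batch_invstd = (var + self.eps) ** -0.5
    return self.batch_invstd                        # clear batch_invstd by hand if running_var changes

print(InvStdCache()(np.ones(4), training=False))    # ~1.0 for unit variance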


@@ -157,6 +157,8 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
   # if there's *one* processing or reduce op in here, we can corealize it. we can corealize binary op siblings as well
   # NOTE: if it references the same conv multiple times, they should already be merged by the dictionary
+  #for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())):
+  #  print(k,x, len(x.children), [x for x in x.children])
   psrcs : List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype in [ProcessingOps,ReduceOps] and x.realized is None and len(x.children) <= 1 and len(k.children) <= 1]
   if len(psrcs) == 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE and (self.device != "OPENCL" or self.shape[-1] == 4):
     if psrcs[0][1].optype == ProcessingOps:
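
The psrcs filter above encodes the corealization rule stated in the comment: a source can be folded into its consumer only if it is an unrealized processing or reduce op and neither side is shared with another node. A toy sketch of just that condition (hypothetical Node class, not tinygrad's LazyBuffer):

class Node:
  # toy stand-in for a lazy buffer: an op type, the nodes that consume it, and whether it has run
  def __init__(self, optype, realized=None):
    self.optype, self.children, self.realized = optype, [], realized

def can_corealize(consumer, src):
  # mirrors the list-comprehension filter: unrealized reduce/processing source,
  # with neither the source nor the consumer feeding more than one other node
  return (src.optype in ("ProcessingOps", "ReduceOps") and src.realized is None
          and len(src.children) <= 1 and len(consumer.children) <= 1)

conv, add = Node("ProcessingOps"), Node("BinaryOps")
conv.children.append(add)           # the conv feeds exactly one elementwise consumer
print(can_corealize(add, conv))     # True: the conv can be merged into the add kernel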