mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix bn folding issue, add new test
This commit is contained in:
66
test/test_opt.py
Normal file
66
test/test_opt.py
Normal file
@@ -0,0 +1,66 @@
|
||||
#!/usr/bin/env python
|
||||
import os
|
||||
os.environ["OPT"] = "2"
|
||||
|
||||
import gc
|
||||
import numpy as np
|
||||
|
||||
import unittest
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad import nn
|
||||
from tinygrad.llops.ops_gpu import CL
|
||||
|
||||
class CLCache():
  """Context manager that captures GPU kernel launches in CL.CACHE.

  On entry, all live Tensors are realized first so pending lazy work does
  not pollute the count, then caching is switched on (CL.CACHE = []).
  On exit, the recorded kernels are actually run and caching is switched
  off (CL.CACHE = None), so the body's launch count can be asserted while
  still inside the `with` block.
  """
  def __enter__(self):
    # realize everything that already exists so only the `with` body's
    # kernels land in the cache
    gc.collect()
    for x in [x for x in gc.get_objects() if isinstance(x, Tensor)]:
      x.realize()
    CL.CACHE = []
    print("cache: entering")
  def __exit__(self, exc_type, exc_value, tb):
    # NOTE: params renamed from (type, value, traceback) — `type` shadowed
    # the builtin; `with` passes these positionally, so callers are unaffected
    print(f"cache: exiting with size {len(CL.CACHE)}")
    # replay the captured kernels for real so results are still correct
    for prg, args in CL.CACHE:
      prg.clprg(CL().cl_queue, *args)   # dropped unused `e =` binding
    CL.CACHE = None
|
||||
|
||||
# Exercise the training code paths, but skip gradient-graph construction:
# these tests only count forward kernels.
Tensor.no_grad = True
Tensor.training = True
||||
@unittest.skipUnless(Device.DEFAULT == Device.GPU, "Not Implemented")
class TestOpt(unittest.TestCase):
  """Kernel-fusion tests: each case counts GPU kernels queued inside CLCache."""

  def test_muladd(self):
    # elementwise mul + add over 2x2 ones should compile to one kernel
    t0, t1, t2 = (Tensor.ones(2,2) for _ in range(3))
    with CLCache():
      out = t0 * t1 + t2
      out.realize()
      assert len(CL.CACHE) == 1, "optimizer didn't fold muladd"
    np.testing.assert_allclose(out.numpy(), np.ones((2,2))*2, rtol=1e-5)

  def test_fold_reduce_elementwise(self):
    # a reduce (sum) feeding an elementwise add should fuse into one kernel
    base = Tensor.ones(32)
    extra = Tensor.ones(1)
    with CLCache():
      total = base.sum() + extra
      total.realize()
      assert len(CL.CACHE) == 1, "optimizer didn't fold reduce/elementwise"
    assert total.numpy()[0] == 33

  def test_fold_batchnorm(self):
    # batchnorm over a (1,32,4,4) input should need exactly three kernels
    inp = Tensor.ones(1,32,4,4)
    norm = nn.BatchNorm2D(32, track_running_stats=False)
    with CLCache():
      normed = norm(inp).realize()
      print(normed)
      assert len(CL.CACHE) == 3, "optimizer didn't fold batchnorm"

  def test_fold_conv_elu(self):
    # each conv should absorb its following elu: two kernels for two convs
    inp = Tensor.ones(1,4,8,8)
    conv1 = nn.Conv2d(4, 4, kernel_size=3)
    conv2 = nn.Conv2d(4, 4, kernel_size=3)
    with CLCache():
      fused = inp.sequential([conv1, Tensor.elu, conv2, Tensor.elu]).realize()
      print(fused)
      assert len(CL.CACHE) == 2, "optimizer didn't fold conv/elu"
||||
# Allow running this file directly: `python test/test_opt.py`
if __name__ == '__main__':
  unittest.main()
@@ -23,6 +23,8 @@ class BatchNorm2D:
|
||||
batch_mean = x_detached.mean(axis=(0,2,3))
|
||||
y = (x_detached - batch_mean.reshape(shape=[1, -1, 1, 1]))
|
||||
batch_var = (y*y).mean(axis=(0,2,3))
|
||||
batch_invstd = batch_var.add(self.eps)**-0.5
|
||||
self.batch_invstd = None
|
||||
|
||||
# NOTE: wow, this is done all throughout training in most PyTorch models
|
||||
if self.track_running_stats:
|
||||
@@ -31,11 +33,12 @@ class BatchNorm2D:
|
||||
self.num_batches_tracked += 1
|
||||
else:
|
||||
batch_mean, batch_var = self.running_mean, self.running_var
|
||||
# NOTE: this can be precomputed for static inference. if you manually update running_var, you have to reset this
|
||||
if getattr(self, "batch_invstd", None) is None:
|
||||
self.batch_invstd = batch_var.add(self.eps)**-0.5
|
||||
batch_invstd = self.batch_invstd
|
||||
|
||||
# NOTE: this can be precomputed for static inference. if you manually update running_var, you have to reset this
|
||||
if Tensor.training or getattr(self, "batch_invstd", None) is None:
|
||||
self.batch_invstd = batch_var.add(self.eps)**-0.5
|
||||
return batch_normalize(x, self.weight, self.bias, batch_mean, self.batch_invstd)
|
||||
return batch_normalize(x, self.weight, self.bias, batch_mean, batch_invstd)
|
||||
|
||||
# TODO: is this good weight init?
|
||||
class Conv2d:
|
||||
|
||||
@@ -157,6 +157,8 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
|
||||
|
||||
# if there's *one* processing or reduce op in here, we can corealize it. we can corealize binary op siblings as well
|
||||
# NOTE: if it references the same conv multiple times, they should already be merged by the dictionary
|
||||
#for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())):
|
||||
# print(k,x, len(x.children), [x for x in x.children])
|
||||
psrcs : List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype in [ProcessingOps,ReduceOps] and x.realized is None and len(x.children) <= 1 and len(k.children) <= 1]
|
||||
if len(psrcs) == 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE and (self.device != "OPENCL" or self.shape[-1] == 4):
|
||||
if psrcs[0][1].optype == ProcessingOps:
|
||||
|
||||
Reference in New Issue
Block a user