mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
jit: optimize before pickle (#9611)
* jit: optimize before pickle * optimize weights * fix * mypy * mypy2
This commit is contained in:
@@ -605,5 +605,24 @@ class TestJitFree(unittest.TestCase):
|
||||
fxn(Tensor([2]))
|
||||
self.assertEqual(x.item(), 8)
|
||||
|
||||
def test_optimize_weights(self):
|
||||
if not hasattr(Device[Device.DEFAULT].allocator, '_offset'): raise unittest.SkipTest("optimize_weights useless")
|
||||
|
||||
ext_tensor = Tensor([1,24,23,45,1])
|
||||
ext_tensor_2 = Tensor([2,2,2,2,2])
|
||||
@TinyJit
|
||||
def fxn(x:Tensor):
|
||||
out = (x*ext_tensor_2+ext_tensor).reshape(5,1).expand(5, 100).contiguous()
|
||||
return out.sum()
|
||||
for i in range(5):
|
||||
out = fxn(Tensor([i,1,2,3,4]))
|
||||
self.assertEqual(out.item(), 11400+200*i)
|
||||
assert len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])) == 4
|
||||
fxn.captured.optimize_weights()
|
||||
assert len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])) == 2
|
||||
|
||||
out = fxn(Tensor([11,1,2,3,4]))
|
||||
self.assertEqual(out.item(), 13600)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user