jit: optimize before pickle (#9611)

* jit: optimize before pickle

* optimize weights

* fix

* mypy

* mypy2
This commit is contained in:
nimlgen
2025-03-28 19:06:09 +07:00
committed by GitHub
parent 392a311312
commit fa0ebbd237
3 changed files with 32 additions and 3 deletions

View File

@@ -605,5 +605,24 @@ class TestJitFree(unittest.TestCase):
fxn(Tensor([2]))
self.assertEqual(x.item(), 8)
def test_optimize_weights(self):
if not hasattr(Device[Device.DEFAULT].allocator, '_offset'): raise unittest.SkipTest("optimize_weights useless")
ext_tensor = Tensor([1,24,23,45,1])
ext_tensor_2 = Tensor([2,2,2,2,2])
@TinyJit
def fxn(x:Tensor):
out = (x*ext_tensor_2+ext_tensor).reshape(5,1).expand(5, 100).contiguous()
return out.sum()
for i in range(5):
out = fxn(Tensor([i,1,2,3,4]))
self.assertEqual(out.item(), 11400+200*i)
assert len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])) == 4
fxn.captured.optimize_weights()
assert len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])) == 2
out = fxn(Tensor([11,1,2,3,4]))
self.assertEqual(out.item(), 13600)
if __name__ == '__main__':
unittest.main()