mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-02 02:35:22 -05:00
Merge branch 'master' into fix_rng_merge
This commit is contained in:
@@ -2,7 +2,7 @@ import unittest, pickle, types
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, TinyJit, Variable, dtypes
|
||||
from tinygrad.helpers import GlobalCounters, ContextVar, Context
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, UOp, Ops
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, UOp
|
||||
|
||||
class TestPickle(unittest.TestCase):
|
||||
def test_pickle_code_object(self):
|
||||
@@ -45,10 +45,9 @@ class TestPickle(unittest.TestCase):
|
||||
t_values = t.numpy()
|
||||
del t # free buffers
|
||||
print("** post pickle")
|
||||
init = GlobalCounters.kernel_count
|
||||
t2:Tensor = pickle.loads(st)
|
||||
assert t2.uop.is_realized
|
||||
np.testing.assert_equal(t_values, t2.numpy())
|
||||
self.assertEqual(GlobalCounters.kernel_count-init, 0)
|
||||
|
||||
def test_pickle_realized_tensor_alt2(self):
|
||||
print("** init")
|
||||
@@ -70,14 +69,14 @@ class TestPickle(unittest.TestCase):
|
||||
def test_pickle_buffer_uop(self):
|
||||
t = Tensor.arange(4).realize()
|
||||
a = t.uop
|
||||
assert a.op is Ops.BUFFER
|
||||
self.assertIsNotNone(buffer:=a.realized)
|
||||
assert a.is_realized
|
||||
self.assertIsNotNone(buffer:=a.base.realized)
|
||||
s = pickle.dumps(a)
|
||||
# free buffers
|
||||
del a
|
||||
del buffer
|
||||
a2:UOp = pickle.loads(s)
|
||||
self.assertListEqual(a2.realized.as_buffer().cast("I").tolist(), [0, 1, 2, 3])
|
||||
self.assertListEqual(a2.base.realized.as_buffer().cast("I").tolist(), [0, 1, 2, 3])
|
||||
|
||||
def test_pickle_unrealized_tensor(self):
|
||||
t = Tensor.ones(10, 10)
|
||||
|
||||
Reference in New Issue
Block a user