mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
this fixes test_train_mnist memory, breaks everything else
This commit is contained in:
@@ -111,7 +111,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 63)
|
||||
helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 104)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"CLANG", "GPU", "LLVM"}, "slow")
|
||||
def test_train_cifar(self):
|
||||
|
||||
@@ -212,7 +212,7 @@ class UOpMetaClass(type):
|
||||
return created
|
||||
|
||||
# some uops map to other stuff
|
||||
buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # this maps BUFFER uops to their device Buffers
|
||||
buffers:Dict[UOp, weakref.ReferenceType[Buffer]] = {} # this maps BUFFER uops to their device Buffers
|
||||
realized:weakref.WeakKeyDictionary[UOp, UOp] = weakref.WeakKeyDictionary() # this maps realized ops to a BUFFER uop
|
||||
forced_realize:weakref.WeakSet[UOp] = weakref.WeakSet()
|
||||
|
||||
@@ -223,9 +223,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
dtype:DType = dtypes.void
|
||||
src:Tuple[UOp, ...] = tuple()
|
||||
arg:Any = None
|
||||
def __del__(self):
|
||||
if self.op is Ops.BUFFER: self.buffer.ref(-1)
|
||||
del UOpMetaClass.ucache[(self.op, self.dtype, self.src, self.arg)]
|
||||
def __del__(self): del UOpMetaClass.ucache[(self.op, self.dtype, self.src, self.arg)]
|
||||
def __reduce__(self): return UOp, (self.op, self.dtype, self.src, self.arg)
|
||||
def replace(self, **kwargs) -> UOp:
|
||||
new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src), kwargs.pop("arg", self.arg))
|
||||
@@ -440,14 +438,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
@property
|
||||
def buffer(self) -> Buffer:
|
||||
if self.base.realized is not None: return self.base.realized
|
||||
if (ret:=buffers.get(self)) is not None: return ret
|
||||
if (wret:=buffers.get(self)) is not None and (ret:=wret()) is not None: return ret
|
||||
if self.op is Ops.VIEW:
|
||||
assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
|
||||
return self.src[0].buffer
|
||||
assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
|
||||
from tinygrad.device import Buffer
|
||||
buffers[self] = ret = Buffer(*self.arg[1])
|
||||
return ret
|
||||
buffers[self] = weakref.ref(created:=Buffer(*self.arg[1]))
|
||||
return created
|
||||
@property
|
||||
def realized(self) -> Optional[Buffer]: return real_buf_uop.buffer if (real_buf_uop:=realized.get(self)) is not None else None
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user