This fixes test_train_mnist memory usage, but breaks everything else

This commit is contained in:
qazal
2024-12-07 17:37:42 +02:00
parent 4f98464bfc
commit dccfcbe068
2 changed files with 6 additions and 8 deletions

View File

@@ -111,7 +111,7 @@ class TestRealWorld(unittest.TestCase):
loss.backward()
optimizer.step()
helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 63)
helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 104)
@unittest.skipIf(CI and Device.DEFAULT in {"CLANG", "GPU", "LLVM"}, "slow")
def test_train_cifar(self):

View File

@@ -212,7 +212,7 @@ class UOpMetaClass(type):
return created
# Module-level side tables keyed by UOp (some uops map to other stuff).
# `buffers` is bound twice below: first as a WeakKeyDictionary, then rebound as a plain dict.
# WeakKeyDictionary form: keys are held weakly, values are the device Buffers themselves.
buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # this maps BUFFER uops to their device Buffers
# plain-dict form: keys are held strongly by the dict, values are weak references to the device Buffers
buffers:Dict[UOp, weakref.ReferenceType[Buffer]] = {} # this maps BUFFER uops to weak references to their device Buffers
realized:weakref.WeakKeyDictionary[UOp, UOp] = weakref.WeakKeyDictionary() # this maps realized ops to a BUFFER uop
# weakly-held set of UOps; NOTE(review): presumably the ops that must be realized regardless of fusion - confirm at the use sites
forced_realize:weakref.WeakSet[UOp] = weakref.WeakSet()
@@ -223,9 +223,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
dtype:DType = dtypes.void
src:Tuple[UOp, ...] = tuple()
arg:Any = None
# Finalizer, multi-line variant: for a BUFFER op it first calls self.buffer.ref(-1)
# (NOTE(review): presumably decrements a reference count on the Buffer - confirm against Buffer.ref),
# then evicts this UOp's (op, dtype, src, arg) key from the UOpMetaClass.ucache table.
def __del__(self):
if self.op is Ops.BUFFER: self.buffer.ref(-1)
del UOpMetaClass.ucache[(self.op, self.dtype, self.src, self.arg)]
# Finalizer, one-line variant: only evicts the (op, dtype, src, arg) key from UOpMetaClass.ucache;
# the Buffer is not touched here.
def __del__(self): del UOpMetaClass.ucache[(self.op, self.dtype, self.src, self.arg)]
# Pickle/copy support: returns (callable, args) so the object is rebuilt by calling
# UOp(op, dtype, src, arg) rather than by restoring instance state directly.
def __reduce__(self): return UOp, (self.op, self.dtype, self.src, self.arg)
def replace(self, **kwargs) -> UOp:
new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src), kwargs.pop("arg", self.arg))
@@ -440,14 +438,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@property
def buffer(self) -> Buffer:
# Return the device Buffer backing this UOp; for a BUFFER op with no live entry, construct one and record it in `buffers`.
# a realized base takes precedence over anything recorded in `buffers`
if self.base.realized is not None: return self.base.realized
# direct-value lookup: return whatever `buffers` stores for this UOp
if (ret:=buffers.get(self)) is not None: return ret
# weakref-valued lookup: dereference the stored weak reference and return it only if the Buffer is still alive
if (wret:=buffers.get(self)) is not None and (ret:=wret()) is not None: return ret
# a VIEW delegates to its first source's buffer; the assert rejects views whose shape tracker is not contiguous
if self.op is Ops.VIEW:
assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
return self.src[0].buffer
# anything that reaches this point must be a BUFFER op
assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
# function-scope import; NOTE(review): presumably avoids a circular import with tinygrad.device - confirm
from tinygrad.device import Buffer
# direct-value form: build the Buffer from the arguments in self.arg[1] and store the Buffer itself
buffers[self] = ret = Buffer(*self.arg[1])
return ret
# weakref form: store only a weak reference, so `buffers` does not keep the Buffer alive; the strong reference is handed to the caller
buffers[self] = weakref.ref(created:=Buffer(*self.arg[1]))
return created
@property
def realized(self) -> Optional[Buffer]: return real_buf_uop.buffer if (real_buf_uop:=realized.get(self)) is not None else None
@property