From 54dc48aa47aa6e4df612d5b5f349f40b2aaa186a Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 22 Mar 2024 11:48:48 -0700 Subject: [PATCH] fix assign (#3878) * fix assign * remove terrible optimizer hack * oops, not realized assigns --- test/test_assign.py | 5 ++--- tinygrad/nn/optim.py | 6 +----- tinygrad/realize.py | 6 ++++++ 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/test/test_assign.py b/test/test_assign.py index 248460908f..9ced8858b3 100644 --- a/test/test_assign.py +++ b/test/test_assign.py @@ -100,13 +100,12 @@ class TestAssign(unittest.TestCase): new = a + times_a np.testing.assert_allclose(new.numpy(), 5) - @unittest.expectedFailure def test_assign_diamond_possible(self): a = Tensor.ones(4).contiguous().realize() times_a = a*3 a.assign(Tensor.full((4,), 2.).contiguous()) - new = a + (times_a+1).contiguous() - np.testing.assert_allclose(new.numpy(), 6) + new = a + (times_a-1).contiguous() + np.testing.assert_allclose(new.numpy(), 4) def test_assign_diamond_alt(self): a = Tensor.ones(4).contiguous().realize() diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py index f18b29c2e7..f1b851bf7f 100644 --- a/tinygrad/nn/optim.py +++ b/tinygrad/nn/optim.py @@ -19,11 +19,7 @@ class Optimizer: for param in self.params: param.grad = None def realize(self, extra=None): - # NOTE: in extra is too late for most of the params due to issues with assign - #Tensor.corealize(extra + self.params + self.buffers if extra is not None else self.params + self.buffers) - # TODO: SUPER BROKEN THAT THIS FIXES IT - if extra is not None: Tensor.corealize(extra) - Tensor.corealize(self.params + self.buffers) + Tensor.corealize(extra + self.params + self.buffers if extra is not None else self.params + self.buffers) def step(self) -> None: raise NotImplementedError diff --git a/tinygrad/realize.py b/tinygrad/realize.py index c6c733ceec..97119d7c82 100644 --- a/tinygrad/realize.py +++ b/tinygrad/realize.py @@ -242,6 +242,7 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) # preschedule all buffers in realizes prescheduled = {x:_schedule_one(x, realizes, reduce_for_op) for x in realizes if x not in seen and x.realized is None and x.op is not LoadOps.CONST} + assign_targets = {x.srcs[1]:x for x in realizes if x.op is LoadOps.ASSIGN and x not in seen and x.realized is None} # breadth first ordering graph: DefaultDict[LazyBuffer,List[LazyBuffer]] = defaultdict(list) @@ -249,6 +250,11 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) queue: Deque[LazyBuffer] = deque() for buf in allbufs: if buf.realized or buf.op is LoadOps.CONST: continue + if buf in prescheduled: + for inp in prescheduled[buf].inputs: + if inp in assign_targets: + graph[buf].append(assign_targets[inp]) + in_degree[assign_targets[inp]] += 1 for x in buf.srcs: if x.base.realized or x.base.op is LoadOps.CONST: continue graph[x.base].append(buf)