diff --git a/test/test_assign.py b/test/test_assign.py index 57e4d9f815..c18d6301cf 100644 --- a/test/test_assign.py +++ b/test/test_assign.py @@ -323,14 +323,28 @@ class TestAssign(unittest.TestCase): b = Tensor.full((32, 32), 1.).contiguous().realize() c = Tensor.full((32, 32), 2.).contiguous().realize() - # TODO: this is failing in cycle error, it should fail earlier. - with self.assertRaisesRegex(RuntimeError, "cycle"): + with self.assertRaisesRegex(RuntimeError, "contiguous"): r = a.sum(axis=1) b_perm = b.permute(1, 0) b.assign(r + b) c.assign(r + b_perm) Tensor.realize(b, c) + def test_permuted_reduceop_multioutput_dual_use_possible(self): + a = Tensor.randn(32, 32, 32, dtype=dtypes.int).realize() + b = Tensor.arange(32 * 32).reshape(32, 32).realize() + c = Tensor.arange(32 * 32).reshape(32, 32).realize() + + kc = GlobalCounters.kernel_count + r = a.sum(axis=1) + b_perm = b.permute(1, 0) + b.assign(r + b) + c.assign(r + b_perm.contiguous()) + Tensor.realize(b, c) + assert GlobalCounters.kernel_count - kc == 2 + np.testing.assert_equal(b.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32)) + np.testing.assert_equal(c.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32).transpose(1, 0)) + # TODO: is there a way to sneak in a permute such that it returns the wrong answer? @unittest.skip("don't use output buffer, and mismatch dtype no longer supported") diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 41c1897e0d..3b611ee9cd 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -43,7 +43,7 @@ class _LBScheduleItem: var_vals: Dict[Variable, int] def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outbufs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker, - realizes:Dict[LazyBuffer, None], cache, assign_to:Optional[LazyBuffer]=None, assign_idx:Optional[int]=None) -> LazyOp: + realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer], cache) -> LazyOp: """recursively create a lazyop""" if (buf, st) in cache: return cache[(buf, st)] if buf != buf.base: @@ -63,26 +63,26 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outbufs:Tuple[Laz if buf.realized or (buf in realizes and buf not in outbufs): unbound_st, st_var_vals = st.simplify().unbind() var_vals.update(st_var_vals) - if assign_to is not None and buf is assign_to: - assert assign_idx is not None + if buf in assign_targets: + # can only assign to contiguous read+write buffer if not unbound_st.contiguous: # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)): raise RuntimeError(f"must be contiguous for assign {unbound_st}") - return LazyOp(BufferOps.LOAD, (), MemBuffer(assign_idx, buf.dtype, unbound_st)) + return LazyOp(BufferOps.LOAD, (), MemBuffer(outbufs.index(assign_targets[buf]), buf.dtype, unbound_st)) if buf not in inputs: inputs.append(buf) return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outbufs)+inputs.index(buf), buf.dtype, unbound_st)) # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it if buf.op is LoadOps.CONTIGUOUS: assert buf in outbufs - return _recursive_lazyop(buf.srcs[0], inputs, outbufs, var_vals, st, realizes, cache) + return _recursive_lazyop(buf.srcs[0], inputs, outbufs, var_vals, st, realizes, assign_targets, cache) if buf.op is LoadOps.ASSIGN: assert buf in outbufs assert buf.srcs[1].base is buf.srcs[1], "assign must be to base" assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}" - return _recursive_lazyop(buf.srcs[0], inputs, outbufs, var_vals, st, realizes, cache, assign_to=buf.srcs[1], assign_idx=outbufs.index(buf)) + return _recursive_lazyop(buf.srcs[0], inputs, outbufs, var_vals, st, realizes, assign_targets, cache) # if it's a reduce, we have to change the shapetracker if buf.op in ReduceOps: @@ -91,7 +91,7 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outbufs:Tuple[Laz # otherwise we fuse it like normal cache[(buf, st)] = ret = \ - LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outbufs, var_vals, st, realizes, cache, assign_to, assign_idx) for x in buf.srcs), buf.arg) + LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outbufs, var_vals, st, realizes, assign_targets, cache) for x in buf.srcs), buf.arg) return ret def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem: @@ -109,10 +109,11 @@ def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None] else: ast = [LazyOp(op, (), out.arg)] # multi output AST else: + assign_targets = {x.srcs[1]:x for x in outs if x.op is LoadOps.ASSIGN} for i, out in enumerate(outs): output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape) output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st - lop = _recursive_lazyop(out, inputs, outs, var_vals, output_st, realizes, cache={}) + lop = _recursive_lazyop(out, inputs, outs, var_vals, output_st, realizes, assign_targets, cache={}) output_view, vv = output_view.simplify().unbind() if vv: var_vals.update(vv) ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))