mirror of https://github.com/tinygrad/tinygrad.git
use assign_targets in LazyOp creation (#4568)
* start
* correct error
* this is possible
* document it
@@ -323,14 +323,28 @@ class TestAssign(unittest.TestCase):
     b = Tensor.full((32, 32), 1.).contiguous().realize()
     c = Tensor.full((32, 32), 2.).contiguous().realize()
 
-    # TODO: this is failing in cycle error, it should fail earlier.
-    with self.assertRaisesRegex(RuntimeError, "cycle"):
+    with self.assertRaisesRegex(RuntimeError, "contiguous"):
       r = a.sum(axis=1)
       b_perm = b.permute(1, 0)
       b.assign(r + b)
       c.assign(r + b_perm)
       Tensor.realize(b, c)
 
+  def test_permuted_reduceop_multioutput_dual_use_possible(self):
+    a = Tensor.randn(32, 32, 32, dtype=dtypes.int).realize()
+    b = Tensor.arange(32 * 32).reshape(32, 32).realize()
+    c = Tensor.arange(32 * 32).reshape(32, 32).realize()
+
+    kc = GlobalCounters.kernel_count
+    r = a.sum(axis=1)
+    b_perm = b.permute(1, 0)
+    b.assign(r + b)
+    c.assign(r + b_perm.contiguous())
+    Tensor.realize(b, c)
+    assert GlobalCounters.kernel_count - kc == 2
+    np.testing.assert_equal(b.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32))
+    np.testing.assert_equal(c.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32).transpose(1, 0))
+
+    # TODO: is there a way to sneak in a permute such that it returns the wrong answer?
+
   @unittest.skip("don't use output buffer, and mismatch dtype no longer supported")
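As a rough standalone sketch of what these two tests exercise (a sketch only: it assumes tinygrad is installed and schedules as in this commit; the `.float()` casts, tolerances, and printed messages are illustration additions, not part of the diff), reading an assign target through a permuted view is rejected at schedule time, while an explicit `.contiguous()` on that read makes the dual use of `b` schedulable:

import numpy as np
from tinygrad import Tensor

a = Tensor.randn(32, 32, 32).realize()
b = Tensor.arange(32 * 32).reshape(32, 32).float().realize()
c = Tensor.arange(32 * 32).reshape(32, 32).float().realize()

# dual use of b: one output assigns into b, the other reads b through a permute.
# the permuted read is a non-contiguous view of an assign target, so scheduling should fail.
r = a.sum(axis=1)
b_perm = b.permute(1, 0)
b.assign(r + b)
c.assign(r + b_perm)
try:
  Tensor.realize(b, c)
except RuntimeError as e:
  print("rejected as expected:", e)

# an explicit .contiguous() on the permuted read makes the dual use schedulable
b = Tensor.arange(32 * 32).reshape(32, 32).float().realize()
c = Tensor.arange(32 * 32).reshape(32, 32).float().realize()
r = a.sum(axis=1)
b_perm = b.permute(1, 0)
b.assign(r + b)
c.assign(r + b_perm.contiguous())
Tensor.realize(b, c)

# c is computed from the pre-assign values of b, as in the test above
expected = np.arange(32 * 32, dtype=np.float32).reshape(32, 32)
np.testing.assert_allclose(b.numpy(), a.numpy().sum(1) + expected, rtol=1e-4, atol=1e-3)
np.testing.assert_allclose(c.numpy(), a.numpy().sum(1) + expected.T, rtol=1e-4, atol=1e-3)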
@@ -43,7 +43,7 @@ class _LBScheduleItem:
   var_vals: Dict[Variable, int]
 
 def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outbufs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
-                      realizes:Dict[LazyBuffer, None], cache, assign_to:Optional[LazyBuffer]=None, assign_idx:Optional[int]=None) -> LazyOp:
+                      realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer], cache) -> LazyOp:
   """recursively create a lazyop"""
   if (buf, st) in cache: return cache[(buf, st)]
   if buf != buf.base:
@@ -63,26 +63,26 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outbufs:Tuple[Laz
   if buf.realized or (buf in realizes and buf not in outbufs):
     unbound_st, st_var_vals = st.simplify().unbind()
     var_vals.update(st_var_vals)
-    if assign_to is not None and buf is assign_to:
-      assert assign_idx is not None
+    if buf in assign_targets:
       # can only assign to contiguous read+write buffer
       if not unbound_st.contiguous:
         # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
         if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and
                 ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
           raise RuntimeError(f"must be contiguous for assign {unbound_st}")
-      return LazyOp(BufferOps.LOAD, (), MemBuffer(assign_idx, buf.dtype, unbound_st))
+      return LazyOp(BufferOps.LOAD, (), MemBuffer(outbufs.index(assign_targets[buf]), buf.dtype, unbound_st))
     if buf not in inputs: inputs.append(buf)
     return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outbufs)+inputs.index(buf), buf.dtype, unbound_st))
 
   # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
   if buf.op is LoadOps.CONTIGUOUS:
     assert buf in outbufs
-    return _recursive_lazyop(buf.srcs[0], inputs, outbufs, var_vals, st, realizes, cache)
+    return _recursive_lazyop(buf.srcs[0], inputs, outbufs, var_vals, st, realizes, assign_targets, cache)
   if buf.op is LoadOps.ASSIGN:
     assert buf in outbufs
     assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
     assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
-    return _recursive_lazyop(buf.srcs[0], inputs, outbufs, var_vals, st, realizes, cache, assign_to=buf.srcs[1], assign_idx=outbufs.index(buf))
+    return _recursive_lazyop(buf.srcs[0], inputs, outbufs, var_vals, st, realizes, assign_targets, cache)
 
   # if it's a reduce, we have to change the shapetracker
   if buf.op in ReduceOps:
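To make the index bookkeeping above concrete, here is a minimal plain-Python sketch with hypothetical stand-in objects (a `Buf` dataclass, not tinygrad's real LazyBuffer/MemBuffer): a realized buffer that is the destination of one of the group's ASSIGNs is loaded from that output's slot, and every other realized buffer is appended to `inputs` and loaded from a slot numbered after all the outputs.

from dataclasses import dataclass
from typing import Dict, List

@dataclass(frozen=True)
class Buf:
  """hypothetical stand-in for a realized LazyBuffer"""
  name: str

def load_index(buf: Buf, outbufs: List[Buf], inputs: List[Buf], assign_targets: Dict[Buf, Buf]) -> int:
  # an assign destination is read from the same slot it will be written to,
  # mirroring MemBuffer(outbufs.index(assign_targets[buf]), ...) above
  if buf in assign_targets: return outbufs.index(assign_targets[buf])
  # everything else is a plain input, numbered after all the output slots,
  # mirroring MemBuffer(len(outbufs)+inputs.index(buf), ...) above
  if buf not in inputs: inputs.append(buf)
  return len(outbufs) + inputs.index(buf)

# two outputs: out_b assigns into b, out_c is a normal store; a is a plain input
a, b = Buf("a"), Buf("b")
out_b, out_c = Buf("out_b"), Buf("out_c")
outbufs, inputs = [out_b, out_c], []
assign_targets = {b: out_b}

assert load_index(b, outbufs, inputs, assign_targets) == 0  # reuses out_b's slot
assert load_index(a, outbufs, inputs, assign_targets) == 2  # first slot after the outputs

The contiguity check in the hunk above then ensures this read-modify-write of a single slot only happens through a contiguous (or equivalently masked) view of the destination.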
@@ -91,7 +91,7 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outbufs:Tuple[Laz
 
   # otherwise we fuse it like normal
   cache[(buf, st)] = ret = \
-    LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outbufs, var_vals, st, realizes, cache, assign_to, assign_idx) for x in buf.srcs), buf.arg)
+    LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outbufs, var_vals, st, realizes, assign_targets, cache) for x in buf.srcs), buf.arg)
   return ret
 
 def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
@@ -109,10 +109,11 @@ def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None]
     else: ast = [LazyOp(op, (), out.arg)]
   # multi output AST
   else:
+    assign_targets = {x.srcs[1]:x for x in outs if x.op is LoadOps.ASSIGN}
     for i, out in enumerate(outs):
       output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
       output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
-      lop = _recursive_lazyop(out, inputs, outs, var_vals, output_st, realizes, cache={})
+      lop = _recursive_lazyop(out, inputs, outs, var_vals, output_st, realizes, assign_targets, cache={})
       output_view, vv = output_view.simplify().unbind()
       if vv: var_vals.update(vv)
       ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))
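The dict built here is what replaces the old `assign_to`/`assign_idx` pair: per multi-output group, it maps each ASSIGN's destination buffer (`x.srcs[1]`) to the ASSIGN output itself, so one recursive walk can resolve any number of assign destinations. A minimal sketch with hypothetical stand-ins (an `Out` dataclass in place of LazyBuffer, the string "ASSIGN" in place of LoadOps.ASSIGN):

from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class Out:
  """hypothetical stand-in for a LazyBuffer scheduled as an output"""
  name: str
  op: str = "ELEMENTWISE"
  srcs: Tuple["Out", ...] = ()

# three outputs in one group: two ASSIGNs and one plain store
b, c = Out("b"), Out("c")
out0 = Out("out0", op="ASSIGN", srcs=(Out("val0"), b))  # srcs[1] is the destination buffer
out1 = Out("out1", op="ASSIGN", srcs=(Out("val1"), c))
out2 = Out("out2")
outs = (out0, out1, out2)

# same comprehension as in the diff, with the stand-in op tag
assign_targets = {x.srcs[1]: x for x in outs if x.op == "ASSIGN"}
assert assign_targets == {b: out0, c: out1}

With the old keyword arguments only one destination could be threaded through the recursion; the dict lets the same pass handle both out0 and out1, which is what the new multi-output test relies on.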