use assign_targets in LazyOp creation (#4568)

* start

* correct error

* this is possible

* document it
This commit is contained in:
qazal
2024-05-13 15:24:35 +08:00
committed by GitHub
parent b0fa97e176
commit 77aa8659f5
2 changed files with 25 additions and 10 deletions

View File

@@ -323,14 +323,28 @@ class TestAssign(unittest.TestCase):
b = Tensor.full((32, 32), 1.).contiguous().realize()
c = Tensor.full((32, 32), 2.).contiguous().realize()
# TODO: this is failing in cycle error, it should fail earlier.
with self.assertRaisesRegex(RuntimeError, "cycle"):
with self.assertRaisesRegex(RuntimeError, "contiguous"):
r = a.sum(axis=1)
b_perm = b.permute(1, 0)
b.assign(r + b)
c.assign(r + b_perm)
Tensor.realize(b, c)
def test_permuted_reduceop_multioutput_dual_use_possible(self):
a = Tensor.randn(32, 32, 32, dtype=dtypes.int).realize()
b = Tensor.arange(32 * 32).reshape(32, 32).realize()
c = Tensor.arange(32 * 32).reshape(32, 32).realize()
kc = GlobalCounters.kernel_count
r = a.sum(axis=1)
b_perm = b.permute(1, 0)
b.assign(r + b)
c.assign(r + b_perm.contiguous())
Tensor.realize(b, c)
assert GlobalCounters.kernel_count - kc == 2
np.testing.assert_equal(b.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32))
np.testing.assert_equal(c.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32).transpose(1, 0))
# TODO: is there a way to sneak in a permute such that it returns the wrong answer?
@unittest.skip("don't use output buffer, and mismatch dtype no longer supported")

View File

@@ -43,7 +43,7 @@ class _LBScheduleItem:
var_vals: Dict[Variable, int]
def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outbufs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
realizes:Dict[LazyBuffer, None], cache, assign_to:Optional[LazyBuffer]=None, assign_idx:Optional[int]=None) -> LazyOp:
realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer], cache) -> LazyOp:
"""recursively create a lazyop"""
if (buf, st) in cache: return cache[(buf, st)]
if buf != buf.base:
@@ -63,26 +63,26 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outbufs:Tuple[Laz
if buf.realized or (buf in realizes and buf not in outbufs):
unbound_st, st_var_vals = st.simplify().unbind()
var_vals.update(st_var_vals)
if assign_to is not None and buf is assign_to:
assert assign_idx is not None
if buf in assign_targets:
# can only assign to contiguous read+write buffer
if not unbound_st.contiguous:
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and
ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
raise RuntimeError(f"must be contiguous for assign {unbound_st}")
return LazyOp(BufferOps.LOAD, (), MemBuffer(assign_idx, buf.dtype, unbound_st))
return LazyOp(BufferOps.LOAD, (), MemBuffer(outbufs.index(assign_targets[buf]), buf.dtype, unbound_st))
if buf not in inputs: inputs.append(buf)
return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outbufs)+inputs.index(buf), buf.dtype, unbound_st))
# if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
if buf.op is LoadOps.CONTIGUOUS:
assert buf in outbufs
return _recursive_lazyop(buf.srcs[0], inputs, outbufs, var_vals, st, realizes, cache)
return _recursive_lazyop(buf.srcs[0], inputs, outbufs, var_vals, st, realizes, assign_targets, cache)
if buf.op is LoadOps.ASSIGN:
assert buf in outbufs
assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
return _recursive_lazyop(buf.srcs[0], inputs, outbufs, var_vals, st, realizes, cache, assign_to=buf.srcs[1], assign_idx=outbufs.index(buf))
return _recursive_lazyop(buf.srcs[0], inputs, outbufs, var_vals, st, realizes, assign_targets, cache)
# if it's a reduce, we have to change the shapetracker
if buf.op in ReduceOps:
@@ -91,7 +91,7 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outbufs:Tuple[Laz
# otherwise we fuse it like normal
cache[(buf, st)] = ret = \
LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outbufs, var_vals, st, realizes, cache, assign_to, assign_idx) for x in buf.srcs), buf.arg)
LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outbufs, var_vals, st, realizes, assign_targets, cache) for x in buf.srcs), buf.arg)
return ret
def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
@@ -109,10 +109,11 @@ def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None]
else: ast = [LazyOp(op, (), out.arg)]
# multi output AST
else:
assign_targets = {x.srcs[1]:x for x in outs if x.op is LoadOps.ASSIGN}
for i, out in enumerate(outs):
output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
lop = _recursive_lazyop(out, inputs, outs, var_vals, output_st, realizes, cache={})
lop = _recursive_lazyop(out, inputs, outs, var_vals, output_st, realizes, assign_targets, cache={})
output_view, vv = output_view.simplify().unbind()
if vv: var_vals.update(vv)
ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))