dynamic assign idx (#3975)

This commit is contained in:
qazal
2024-03-28 22:59:32 +02:00
committed by GitHub
parent 2cfcb5623a
commit 2bfb1d3e39

View File

@@ -14,7 +14,7 @@ sys.setrecursionlimit(10000)
# recursively create a lazyop
def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
realizes:Set[LazyBuffer], cache, first=True, assign_to:Optional[LazyBuffer]=None) -> LazyOp:
realizes:Set[LazyBuffer], cache, first=True, assign_to:Optional[LazyBuffer]=None, assign_idx:Optional[int]=None) -> LazyOp:
if (buf, st) in cache: return cache[(buf, st)]
if buf != buf.base:
st = buf.st + st
@@ -33,12 +33,13 @@ def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Va
unbound_st, st_var_vals = st.simplify().unbind()
var_vals.update(st_var_vals)
if assign_to is not None and buf is assign_to:
assert assign_idx is not None
if not unbound_st.contiguous:
# we also allow masked views: a single view with a mask is fine if shrinking a contiguous ShapeTracker of the same shape by that mask gives the same result as shrinking this one
if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and
ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
raise RuntimeError(f"must be contiguous for assign {unbound_st}")
return LazyOp(BufferOps.LOAD, (), MemBuffer(0, buf.dtype, unbound_st))
return LazyOp(BufferOps.LOAD, (), MemBuffer(assign_idx, buf.dtype, unbound_st))
if buf not in membufs: membufs.append(buf)
return LazyOp(BufferOps.LOAD, (), MemBuffer(membufs.index(buf), buf.dtype, unbound_st))
@@ -50,7 +51,7 @@ def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Va
assert first
assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
return _recursive_lazyop(buf.srcs[0], membufs, var_vals, st, realizes, cache, False, assign_to=buf.srcs[1])
return _recursive_lazyop(buf.srcs[0], membufs, var_vals, st, realizes, cache, False, assign_to=buf.srcs[1], assign_idx=membufs.index(buf))
# if it's a reduce, we have to change the shapetracker
if buf.op in ReduceOps:
@@ -59,7 +60,7 @@ def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Va
# otherwise we fuse it like normal
cache[(buf, st)] = ret = \
LazyOp(buf.op, tuple(_recursive_lazyop(x, membufs, var_vals, st, realizes, cache, False, assign_to) for x in buf.srcs), buf.arg)
LazyOp(buf.op, tuple(_recursive_lazyop(x, membufs, var_vals, st, realizes, cache, False, assign_to, assign_idx) for x in buf.srcs), buf.arg)
return ret
def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> ScheduleItem: