dynamic assign idx (#3975)
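Summary: this change threads a new optional assign_idx argument through _recursive_lazyop so that the BufferOps.LOAD emitted for an assign target points at the target's actual position in membufs instead of a hard-coded slot 0, presumably so assign still works when the target buffer is not the kernel's first buffer.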
@@ -14,7 +14,7 @@ sys.setrecursionlimit(10000)
 # recursively create a lazyop
 def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
-                      realizes:Set[LazyBuffer], cache, first=True, assign_to:Optional[LazyBuffer]=None) -> LazyOp:
+                      realizes:Set[LazyBuffer], cache, first=True, assign_to:Optional[LazyBuffer]=None, assign_idx:Optional[int]=None) -> LazyOp:
   if (buf, st) in cache: return cache[(buf, st)]
   if buf != buf.base:
     st = buf.st + st
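The hunk above only extends the signature: _recursive_lazyop gains an optional assign_idx parameter. The hunks below compute it at the ASSIGN node and consume it at the load for the assign target.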
@@ -33,12 +33,13 @@ def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Va
   unbound_st, st_var_vals = st.simplify().unbind()
   var_vals.update(st_var_vals)
   if assign_to is not None and buf is assign_to:
+    assert assign_idx is not None
     if not unbound_st.contiguous:
       # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
       if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and
           ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
         raise RuntimeError(f"must be contiguous for assign {unbound_st}")
-    return LazyOp(BufferOps.LOAD, (), MemBuffer(0, buf.dtype, unbound_st))
+    return LazyOp(BufferOps.LOAD, (), MemBuffer(assign_idx, buf.dtype, unbound_st))
   if buf not in membufs: membufs.append(buf)
   return LazyOp(BufferOps.LOAD, (), MemBuffer(membufs.index(buf), buf.dtype, unbound_st))
 
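This hunk is the core of the change: when the buffer being loaded is the assign target, the emitted MemBuffer now uses the threaded assign_idx instead of a hard-coded 0, which assumed the target always occupied slot 0 of the kernel's buffer list. A minimal sketch of the slot-resolution logic in plain Python (invented names, strings standing in for LazyBuffers, not tinygrad's real classes):

from typing import List, Optional

def load_slot(membufs: List[str], buf: str,
              assign_to: Optional[str] = None, assign_idx: Optional[int] = None) -> int:
  # mirrors the diff's indexing: the assign target loads from its real slot
  if assign_to is not None and buf is assign_to:
    assert assign_idx is not None
    return assign_idx              # old behavior was: return 0
  if buf not in membufs: membufs.append(buf)
  return membufs.index(buf)

# if the assign target is not the kernel's first buffer, the old hard-coded 0
# would have pointed the load at whatever lives in slot 0:
membufs = ["some_output", "assign_target"]
assert load_slot(membufs, membufs[1], assign_to=membufs[1], assign_idx=1) == 1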
@@ -50,7 +51,7 @@ def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Va
     assert first
     assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
     assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
-    return _recursive_lazyop(buf.srcs[0], membufs, var_vals, st, realizes, cache, False, assign_to=buf.srcs[1])
+    return _recursive_lazyop(buf.srcs[0], membufs, var_vals, st, realizes, cache, False, assign_to=buf.srcs[1], assign_idx=membufs.index(buf))
 
   # if it's a reduce, we have to change the shapetracker
   if buf.op in ReduceOps:
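Here the ASSIGN node supplies the value: the assign's own LazyBuffer already sits somewhere in membufs, and membufs.index(buf) travels down as assign_idx, so the leaf load in the previous hunk resolves to that same slot no matter where the target appears in the expression tree.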
@@ -59,7 +60,7 @@ def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Va
 
   # otherwise we fuse it like normal
   cache[(buf, st)] = ret = \
-    LazyOp(buf.op, tuple(_recursive_lazyop(x, membufs, var_vals, st, realizes, cache, False, assign_to) for x in buf.srcs), buf.arg)
+    LazyOp(buf.op, tuple(_recursive_lazyop(x, membufs, var_vals, st, realizes, cache, False, assign_to, assign_idx) for x in buf.srcs), buf.arg)
   return ret
 
 def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> ScheduleItem:
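To see the whole threading pattern in one place, here is a compact, runnable sketch with stand-in types (a toy Node class and string output, not tinygrad's real LazyOp/MemBuffer AST; the real scheduler also registers output buffers before recursing, which is simplified away here). assign_idx is computed once at the ASSIGN node and carried unchanged through every recursive call, so a leaf that is the assign target resolves to the same buffer slot at any depth:

from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(eq=False)  # identity-based equality, like distinct buffer objects
class Node:
  op: str
  srcs: Tuple["Node", ...] = ()

def build(node: Node, membufs: list, assign_to: Optional[Node] = None,
          assign_idx: Optional[int] = None) -> str:
  if node.op == "LOAD":
    if assign_to is not None and node is assign_to:
      assert assign_idx is not None
      return f"LOAD[{assign_idx}]"   # the threaded slot, not a hard-coded 0
    if node not in membufs: membufs.append(node)
    return f"LOAD[{membufs.index(node)}]"
  if node.op == "ASSIGN":
    membufs.append(node)             # the assign's buffer takes the next free slot
    return build(node.srcs[0], membufs, assign_to=node.srcs[1],
                 assign_idx=membufs.index(node))
  # any other op: thread assign_to/assign_idx through unchanged, as in the last hunk
  return f"{node.op}(" + ", ".join(build(s, membufs, assign_to, assign_idx)
                                   for s in node.srcs) + ")"

target = Node("LOAD")
expr = Node("ASSIGN", (Node("ADD", (target, Node("LOAD"))), target))
print(build(expr, membufs=["other_output"]))   # ADD(LOAD[1], LOAD[2])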