mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 07:18:40 -05:00
to_uop split paths part 2 [pr] (#7746)
This commit is contained in:
@@ -68,25 +68,26 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children:DefaultDict[UOp, Dict[U
|
||||
if buf.op is Ops.CONST:
|
||||
if isinstance(val:=buf.arg, UOp): ctx.var_vals.update([val.unbind()])
|
||||
return UOp(Ops.VALID, dtypes.bool, (buf.st.to_uop(),)).where(UOp.const(dtype, val), 0)
|
||||
# everything else has BUFFER
|
||||
ubuf = ctx.buf_uops.setdefault(b:=buf.buffer, UOp.new_buffer(b.device, b.size, b.dtype, num=len(ctx.buf_uops)))
|
||||
# if the buffer is already realized we just load it
|
||||
if buf.is_realized(): return UOp(Ops.VIEW, dtype, (ubuf,), buf.st)
|
||||
# everything else needs sources
|
||||
src = tuple(to_uop(x, ctx, children, allbufs, double_reduces, cache) for x in buf.srcs)
|
||||
if buf.op is Ops.ASSIGN:
|
||||
ctx.assigns.add(ubuf)
|
||||
ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
|
||||
else: ret = UOp(cast(Ops, buf.op), dtype, src, None if buf.op in {Ops.CAST, Ops.BITCAST} else buf.arg)
|
||||
if buf.forced_realize: ret = UOp(Ops.CONTIGUOUS, dtype, (ret,))
|
||||
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf, ret), buf.st)
|
||||
if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
|
||||
ctx.lazybufs[b] = buf
|
||||
# things for fuse.py
|
||||
allbufs[ubuf] = ret
|
||||
if buf.op is Ops.REDUCE_AXIS and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[ubuf] = None
|
||||
for x in src:
|
||||
if is_scheduled(x.base): children[x.base.buf_uop][ubuf] = None
|
||||
# everything else is a VIEW of BUFFER (with an optional op)
|
||||
if buf.is_realized():
|
||||
ctx.buf_uops[buf.buffer] = ubuf = UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(ctx.buf_uops))
|
||||
op = None
|
||||
elif buf.op is Ops.ASSIGN:
|
||||
target, new_val = [to_uop(x, ctx, children, allbufs, double_reduces, cache) for x in buf.srcs]
|
||||
ctx.assigns.add(ubuf:=target.buf_uop)
|
||||
op = UOp(Ops.ASSIGN, dtype, (ubuf, new_val), buf.arg)
|
||||
else:
|
||||
ctx.buf_uops[buf.buffer] = ubuf = UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(ctx.buf_uops))
|
||||
op = UOp(cast(Ops, buf.op), dtype, tuple(to_uop(x, ctx, children, allbufs, double_reduces, cache) for x in buf.srcs),
|
||||
None if buf.op in {Ops.CAST, Ops.BITCAST} else buf.arg)
|
||||
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)
|
||||
if op is not None:
|
||||
if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
|
||||
ctx.lazybufs[buf.buffer] = buf
|
||||
for x in op.src:
|
||||
if is_scheduled(x.base): children[x.base.buf_uop][ubuf] = None
|
||||
allbufs[ubuf] = ret
|
||||
if buf.op is Ops.REDUCE_AXIS and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[ubuf] = None
|
||||
return ret
|
||||
|
||||
# **** AST graph rewrite
|
||||
|
||||
@@ -336,6 +336,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=(idx, False))
|
||||
def r(self, op, axis):
  """Wrap this UOp in a REDUCE_AXIS node.

  The new node keeps `self.dtype`, has `self` as its only source and stores
  the pair `(op, axis)` as its arg.
  """
  reduce_arg = (op, axis)
  return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), reduce_arg)
|
||||
def assign(self, x:UOp):
  """Build an ASSIGN node whose sources are `(self, x)`, typed with `self.dtype`."""
  operands = (self, x)
  return UOp(Ops.ASSIGN, self.dtype, operands)
|
||||
def contiguous(self):
  """Wrap this UOp in a CONTIGUOUS node with the same dtype and `self` as its only source."""
  operands = (self,)
  return UOp(Ops.CONTIGUOUS, self.dtype, operands)
|
||||
@property
def is_contiguous_base(self):
  """True when this is a CONTIGUOUS node whose source base is not a two-source VIEW.

  Any node whose op is not CONTIGUOUS yields False. For a CONTIGUOUS node the
  result is False only if the base of its first source is a VIEW with exactly
  two sources.
  """
  if self.op is not Ops.CONTIGUOUS: return False
  inner = self.src[0].base
  # De Morgan form of: not (inner is a VIEW and it has exactly two sources)
  return inner.op is not Ops.VIEW or len(inner.src) != 2
|
||||
|
||||
|
||||
Reference in New Issue
Block a user