to_uop split paths part 2 [pr] (#7746)

This commit is contained in:
qazal
2024-11-17 15:07:28 +02:00
committed by GitHub
parent 99024b922b
commit 40642cb9ea
2 changed files with 21 additions and 19 deletions

View File

@@ -68,25 +68,26 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children:DefaultDict[UOp, Dict[U
if buf.op is Ops.CONST:
if isinstance(val:=buf.arg, UOp): ctx.var_vals.update([val.unbind()])
return UOp(Ops.VALID, dtypes.bool, (buf.st.to_uop(),)).where(UOp.const(dtype, val), 0)
# everything else has BUFFER
ubuf = ctx.buf_uops.setdefault(b:=buf.buffer, UOp.new_buffer(b.device, b.size, b.dtype, num=len(ctx.buf_uops)))
# if the buffer is already realized we just load it
if buf.is_realized(): return UOp(Ops.VIEW, dtype, (ubuf,), buf.st)
# everything else needs sources
src = tuple(to_uop(x, ctx, children, allbufs, double_reduces, cache) for x in buf.srcs)
if buf.op is Ops.ASSIGN:
ctx.assigns.add(ubuf)
ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
else: ret = UOp(cast(Ops, buf.op), dtype, src, None if buf.op in {Ops.CAST, Ops.BITCAST} else buf.arg)
if buf.forced_realize: ret = UOp(Ops.CONTIGUOUS, dtype, (ret,))
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf, ret), buf.st)
if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
ctx.lazybufs[b] = buf
# things for fuse.py
allbufs[ubuf] = ret
if buf.op is Ops.REDUCE_AXIS and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[ubuf] = None
for x in src:
if is_scheduled(x.base): children[x.base.buf_uop][ubuf] = None
# everything else is a VIEW of BUFFER (with an optional op)
if buf.is_realized():
ctx.buf_uops[buf.buffer] = ubuf = UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(ctx.buf_uops))
op = None
elif buf.op is Ops.ASSIGN:
target, new_val = [to_uop(x, ctx, children, allbufs, double_reduces, cache) for x in buf.srcs]
ctx.assigns.add(ubuf:=target.buf_uop)
op = UOp(Ops.ASSIGN, dtype, (ubuf, new_val), buf.arg)
else:
ctx.buf_uops[buf.buffer] = ubuf = UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(ctx.buf_uops))
op = UOp(cast(Ops, buf.op), dtype, tuple(to_uop(x, ctx, children, allbufs, double_reduces, cache) for x in buf.srcs),
None if buf.op in {Ops.CAST, Ops.BITCAST} else buf.arg)
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)
if op is not None:
if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
ctx.lazybufs[buf.buffer] = buf
for x in op.src:
if is_scheduled(x.base): children[x.base.buf_uop][ubuf] = None
allbufs[ubuf] = ret
if buf.op is Ops.REDUCE_AXIS and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[ubuf] = None
return ret
# **** AST graph rewrite

View File

@@ -336,6 +336,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=(idx, False))
# Reduce helper: builds a REDUCE_AXIS UOp whose only source is self, keeping self's dtype; the arg is the (op, axis) pair.
def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
# Builds an ASSIGN UOp with sources (self, x) and self's dtype; no arg is passed.
# NOTE(review): self is presumably the assignment target and x the new value - confirm against callers.
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
# Wraps self in a CONTIGUOUS UOp of the same dtype, with self as the single source.
def contiguous(self): return UOp(Ops.CONTIGUOUS, self.dtype, (self,))
# True when self is a CONTIGUOUS op and the base of its source is NOT a VIEW with exactly two sources.
# NOTE(review): a two-source VIEW looks like (buffer, op), i.e. a buffer view that still carries the op computing it - confirm.
@property
def is_contiguous_base(self): return self.op is Ops.CONTIGUOUS and not (self.src[0].base.op is Ops.VIEW and len(self.src[0].base.src) == 2)