Lowerer cleanup 2 [run_process_replay] (#5376)

* test outbufs delete

* comments

* valid is bool
This commit is contained in:
qazal
2024-07-11 10:56:53 +03:00
committed by GitHub
parent 9ca2d96b6b
commit 289fd2e940
2 changed files with 8 additions and 10 deletions

View File

@@ -71,9 +71,8 @@ class Kernel:
return cached_ordered_lazyops[op]
self.reduceops = dedup([x for out in self.ast for x in ordered_lazyops(out) if x.op in ReduceOps])
self.outbufs, self.vars = [x.arg for x in self.ast], flatten([x.vars() for x in self.ast])
loadops = [BufferOps.LOAD, BufferOps.CONST]
self.bufs: List[Union[MemBuffer, ConstBuffer]] = self.outbufs + dedup([x.arg for x in self.lazyops if x.op in loadops])
self.vars = flatten([x.vars() for x in self.ast])
self.bufs: List[Union[MemBuffer, ConstBuffer]] = dedup([x.arg for x in self.lazyops if x.op in BufferOps])
# get earlybufs, before any reduceops
self.earlybufs = [x.arg for reduceop in self.reduceops for x in reduceop.lazyops if x.op in BufferOps]
@@ -112,8 +111,8 @@ class Kernel:
ret.opts, ret.ast, ret.lazyops = self.opts, self.ast, self.lazyops
# things downstream of the AST
ret.reduceops, ret.outbufs, ret.vars, ret.bufs, ret.earlybufs, ret.full_buf_index = \
self.reduceops, self.outbufs, self.vars, self.bufs, self.earlybufs, self.full_buf_index
ret.reduceops, ret.vars, ret.bufs, ret.earlybufs, ret.full_buf_index = \
self.reduceops, self.vars, self.bufs, self.earlybufs, self.full_buf_index
ret.sts = self.sts[:len(ret.bufs)] # NOTE: must redo the local buffers with TC in beam
# parameters for optimizations
@@ -632,7 +631,7 @@ class Kernel:
def name(self) -> str:
# kernel name (before late upcast)
name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \
(f"{len(self.outbufs)}_" if len(self.outbufs) > 1 else "_") + \
(f"{len(self.ast)}_" if len(self.ast) > 1 else "_") + \
colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
# name the function something unique

View File

@@ -54,20 +54,19 @@ class Lowerer(Kernel):
if x.op in BufferOps:
idx, valid = st_to_uops(x.arg.st, self.ridxs if x.op is BufferOps.LOAD and x.arg.idx == -1 else self.idxs)
# TODO: check has_valid in UPat, not here
has_valid = valid.op is not UOps.CONST or (valid.arg is not True and valid.arg != 1)
has_valid = valid.op is not UOps.CONST or valid.arg is not True
if x.op is BufferOps.CONST:
dtype = x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype
return UOp.alu(TernaryOps.WHERE, valid, UOp.const(dtype, x.arg.val), UOp.const(dtype, 0))
if x.arg.idx == -1:
buf = UOp(UOps.DEFINE_LOCAL, PtrDType(x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype), (), ("temp", x.arg.st.size))
else:
# NOTE: outbufs is quickly findable in AST
buf = UOp(UOps.DEFINE_GLOBAL, x.arg.dtype if isinstance(x.arg.dtype, ImageDType) else PtrDType(x.arg.dtype), (),
(x.arg.idx, any(x.arg.idx == y.idx for y in self.outbufs)))
(x.arg.idx, any(x.arg.idx == y.arg.idx for y in self.ast)))
if x.op is BufferOps.LOAD:
barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[0]),)),) if len(x.src) else ()
return UOp(UOps.LOAD, x.arg.dtype.scalar(), (buf, idx) + ((valid, UOp.const(x.arg.dtype.scalar(), 0)) if has_valid else ()) + barrier)
# TODO: what is this?
# NOTE: only store the local reduceop in the first thread
if self.group_for_reduces > 0 and x.arg.idx != -1: valid, has_valid = valid * self.idxs[self.first_reduce].eq(0), True
return UOp(UOps.STORE, None, (buf, idx, self.to_uop(x.src[0])) + ((valid,) if has_valid else ()))