mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-27 15:58:10 -05:00
Lowerer cleanup 2 [run_process_replay] (#5376)
* test outbufs delete
* comments
* valid is bool
This commit is contained in:
@@ -71,9 +71,8 @@ class Kernel:
|
||||
return cached_ordered_lazyops[op]
|
||||
self.reduceops = dedup([x for out in self.ast for x in ordered_lazyops(out) if x.op in ReduceOps])
|
||||
|
||||
self.outbufs, self.vars = [x.arg for x in self.ast], flatten([x.vars() for x in self.ast])
|
||||
loadops = [BufferOps.LOAD, BufferOps.CONST]
|
||||
self.bufs: List[Union[MemBuffer, ConstBuffer]] = self.outbufs + dedup([x.arg for x in self.lazyops if x.op in loadops])
|
||||
self.vars = flatten([x.vars() for x in self.ast])
|
||||
self.bufs: List[Union[MemBuffer, ConstBuffer]] = dedup([x.arg for x in self.lazyops if x.op in BufferOps])
|
||||
|
||||
# get earlybufs, before any reduceops
|
||||
self.earlybufs = [x.arg for reduceop in self.reduceops for x in reduceop.lazyops if x.op in BufferOps]
|
||||
@@ -112,8 +111,8 @@ class Kernel:
|
||||
ret.opts, ret.ast, ret.lazyops = self.opts, self.ast, self.lazyops
|
||||
|
||||
# things downstream of the AST
|
||||
ret.reduceops, ret.outbufs, ret.vars, ret.bufs, ret.earlybufs, ret.full_buf_index = \
|
||||
self.reduceops, self.outbufs, self.vars, self.bufs, self.earlybufs, self.full_buf_index
|
||||
ret.reduceops, ret.vars, ret.bufs, ret.earlybufs, ret.full_buf_index = \
|
||||
self.reduceops, self.vars, self.bufs, self.earlybufs, self.full_buf_index
|
||||
ret.sts = self.sts[:len(ret.bufs)] # NOTE: must redo the local buffers with TC in beam
|
||||
|
||||
# parameters for optimizations
|
||||
@@ -632,7 +631,7 @@ class Kernel:
|
||||
def name(self) -> str:
|
||||
# kernel name (before late upcast)
|
||||
name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \
|
||||
(f"{len(self.outbufs)}_" if len(self.outbufs) > 1 else "_") + \
|
||||
(f"{len(self.ast)}_" if len(self.ast) > 1 else "_") + \
|
||||
colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
|
||||
|
||||
# name the function something unique
|
||||
|
||||
@@ -54,20 +54,19 @@ class Lowerer(Kernel):
|
||||
if x.op in BufferOps:
|
||||
idx, valid = st_to_uops(x.arg.st, self.ridxs if x.op is BufferOps.LOAD and x.arg.idx == -1 else self.idxs)
|
||||
# TODO: check has_valid in UPat, not here
|
||||
has_valid = valid.op is not UOps.CONST or (valid.arg is not True and valid.arg != 1)
|
||||
has_valid = valid.op is not UOps.CONST or valid.arg is not True
|
||||
if x.op is BufferOps.CONST:
|
||||
dtype = x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype
|
||||
return UOp.alu(TernaryOps.WHERE, valid, UOp.const(dtype, x.arg.val), UOp.const(dtype, 0))
|
||||
if x.arg.idx == -1:
|
||||
buf = UOp(UOps.DEFINE_LOCAL, PtrDType(x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype), (), ("temp", x.arg.st.size))
|
||||
else:
|
||||
# NOTE: outbufs is quickly findable in AST
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, x.arg.dtype if isinstance(x.arg.dtype, ImageDType) else PtrDType(x.arg.dtype), (),
|
||||
(x.arg.idx, any(x.arg.idx == y.idx for y in self.outbufs)))
|
||||
(x.arg.idx, any(x.arg.idx == y.arg.idx for y in self.ast)))
|
||||
if x.op is BufferOps.LOAD:
|
||||
barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[0]),)),) if len(x.src) else ()
|
||||
return UOp(UOps.LOAD, x.arg.dtype.scalar(), (buf, idx) + ((valid, UOp.const(x.arg.dtype.scalar(), 0)) if has_valid else ()) + barrier)
|
||||
# TODO: what is this?
|
||||
# NOTE: only store the local reduceop in the first thread
|
||||
if self.group_for_reduces > 0 and x.arg.idx != -1: valid, has_valid = valid * self.idxs[self.first_reduce].eq(0), True
|
||||
return UOp(UOps.STORE, None, (buf, idx, self.to_uop(x.src[0])) + ((valid,) if has_valid else ()))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user