From 289fd2e940145056bb98bf0643dd4627f0ddd705 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 11 Jul 2024 10:56:53 +0300 Subject: [PATCH] Lowerer cleanup 2 [run_process_replay] (#5376) * test outbufs delete * comments * valid is bool --- tinygrad/codegen/kernel.py | 11 +++++------ tinygrad/codegen/lowerer.py | 7 +++---- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 2f8605c978..6597a20714 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -71,9 +71,8 @@ class Kernel: return cached_ordered_lazyops[op] self.reduceops = dedup([x for out in self.ast for x in ordered_lazyops(out) if x.op in ReduceOps]) - self.outbufs, self.vars = [x.arg for x in self.ast], flatten([x.vars() for x in self.ast]) - loadops = [BufferOps.LOAD, BufferOps.CONST] - self.bufs: List[Union[MemBuffer, ConstBuffer]] = self.outbufs + dedup([x.arg for x in self.lazyops if x.op in loadops]) + self.vars = flatten([x.vars() for x in self.ast]) + self.bufs: List[Union[MemBuffer, ConstBuffer]] = dedup([x.arg for x in self.lazyops if x.op in BufferOps]) # get earlybufs, before any reduceops self.earlybufs = [x.arg for reduceop in self.reduceops for x in reduceop.lazyops if x.op in BufferOps] @@ -112,8 +111,8 @@ class Kernel: ret.opts, ret.ast, ret.lazyops = self.opts, self.ast, self.lazyops # things downstream of the AST - ret.reduceops, ret.outbufs, ret.vars, ret.bufs, ret.earlybufs, ret.full_buf_index = \ - self.reduceops, self.outbufs, self.vars, self.bufs, self.earlybufs, self.full_buf_index + ret.reduceops, ret.vars, ret.bufs, ret.earlybufs, ret.full_buf_index = \ + self.reduceops, self.vars, self.bufs, self.earlybufs, self.full_buf_index ret.sts = self.sts[:len(ret.bufs)] # NOTE: must redo the local buffers with TC in beam # parameters for optimizations @@ -632,7 +631,7 @@ class Kernel: def name(self) -> str: # kernel name (before late upcast) name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \ - (f"{len(self.outbufs)}_" if len(self.outbufs) > 1 else "_") + \ + (f"{len(self.ast)}_" if len(self.ast) > 1 else "_") + \ colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())]) # name the function something unique diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index 96c679492c..805f399019 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -54,20 +54,19 @@ class Lowerer(Kernel): if x.op in BufferOps: idx, valid = st_to_uops(x.arg.st, self.ridxs if x.op is BufferOps.LOAD and x.arg.idx == -1 else self.idxs) # TODO: check has_valid in UPat, not here - has_valid = valid.op is not UOps.CONST or (valid.arg is not True and valid.arg != 1) + has_valid = valid.op is not UOps.CONST or valid.arg is not True if x.op is BufferOps.CONST: dtype = x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype return UOp.alu(TernaryOps.WHERE, valid, UOp.const(dtype, x.arg.val), UOp.const(dtype, 0)) if x.arg.idx == -1: buf = UOp(UOps.DEFINE_LOCAL, PtrDType(x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype), (), ("temp", x.arg.st.size)) else: - # NOTE: outbufs is quickly findable in AST buf = UOp(UOps.DEFINE_GLOBAL, x.arg.dtype if isinstance(x.arg.dtype, ImageDType) else PtrDType(x.arg.dtype), (), - (x.arg.idx, any(x.arg.idx == y.idx for y in self.outbufs))) + (x.arg.idx, any(x.arg.idx == y.arg.idx for y in self.ast))) if x.op is BufferOps.LOAD: barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[0]),)),) if len(x.src) else () return UOp(UOps.LOAD, x.arg.dtype.scalar(), (buf, idx) + ((valid, UOp.const(x.arg.dtype.scalar(), 0)) if has_valid else ()) + barrier) - # TODO: what is this? + # NOTE: only store the local reduceop in the first thread if self.group_for_reduces > 0 and x.arg.idx != -1: valid, has_valid = valid * self.idxs[self.first_reduce].eq(0), True return UOp(UOps.STORE, None, (buf, idx, self.to_uop(x.src[0])) + ((valid,) if has_valid else ()))