remove Kernel.lazyops [run_process_replay] (#5517)

always use Kernel.ast.lazyops
This commit is contained in:
chenyu
2024-07-16 19:47:42 -04:00
committed by GitHub
parent 1c1d6d3a4a
commit 4ad83d032e
2 changed files with 5 additions and 5 deletions

View File

@@ -69,7 +69,6 @@ class Kernel:
print("INVALID AST")
for op in ast: print_tree(op)
raise e
self.lazyops = self.ast.lazyops
cached_ordered_lazyops: Dict[LazyOp, List[LazyOp]] = {}
def ordered_lazyops(op):
@@ -78,7 +77,7 @@ class Kernel:
self.reduceops = dedup([x for x in ordered_lazyops(self.ast) if x.op in ReduceOps])
self.vars = self.ast.vars()
self.bufs: List[Union[MemBuffer, ConstBuffer]] = dedup([x.arg for x in self.lazyops if x.op in BufferOps])
self.bufs: List[Union[MemBuffer, ConstBuffer]] = dedup([x.arg for x in self.ast.lazyops if x.op in BufferOps])
# get earlybufs, before any reduceops
earlybufs = [x.arg for reduceop in self.reduceops for x in reduceop.lazyops if x.op in BufferOps]
@@ -121,7 +120,7 @@ class Kernel:
ret = type(self).__new__(type(self))
# base linearizer params
ret.opts, ret.ast, ret.lazyops = self.opts, self.ast, self.lazyops
ret.opts, ret.ast = self.opts, self.ast
# things downstream of the AST
ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = \
@@ -492,7 +491,7 @@ class Kernel:
# ok to pad SUM if all parent ops have f(0) = 0
if self.first_reduce <= axis:
check((r:=cast(LazyOp, self.reduceop)).op is ReduceOps.SUM and \
all(op.op not in UNSAFE_PAD_OPS for ops in r.src for op in ops.lazyops), "cannot pad")
all(op.op not in UNSAFE_PAD_OPS for sop in r.src for op in sop.lazyops), "cannot pad")
padded = False
for i,st in enumerate(self.sts):
if self.sts[i].shape[axis] == 1: continue # reduced
@@ -645,7 +644,7 @@ class Kernel:
@functools.cached_property
def name(self) -> str:
# kernel name (before late upcast)
name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \
name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.ast.lazyops) else "E")) + \
(f"{len(self.ast.src)}_" if len(self.ast.src) > 1 else "_") + \
colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])

View File

@@ -75,6 +75,7 @@ class IndependentLowerer:
# NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
full_shape = ast.full_shape
first_upcasted = len(full_shape)-ki.upcasted
# if there's no reduce, this is first_upcasted
first_reduce = [x!=y for x,y in zip(ast.src[0].arg.st.shape[:first_upcasted]+(0,), full_shape[:first_upcasted]+(1,))].index(True)
local_loads = [x for x in ast.lazyops if x.op is BufferOps.LOAD and x.arg.idx == -1]
# NOTE: this is taking the first one...there may be subtleties here with multireduces