simpler processed check

This commit is contained in:
George Hotz
2023-02-10 22:49:20 -06:00
parent 609477656e
commit 1fb5b8069b

View File

@@ -59,12 +59,9 @@ class ASTKernel:
self.ret = output_buffer if output_buffer else type(self.bufs[0])(output_shape if output_shape else self.info.shape, force_create=True)
self.bufs = ([type(self.ret)(self.info.shape, hostbuf=self.ret)] if output_shape else [self.ret]) + self.bufs
# TODO: should be optional if it's hitting a function cache
self.processed = False
def process(self) -> None:
if self.processed: return
self.processed = True
if hasattr(self, "sts"): return # already processed
reduceops = [x for x in get_lazyops(self.ast) if x.op in ReduceOps]
assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast"
self.reduceop = reduceops[0] if reduceops else None