diff --git a/tinygrad/ast.py b/tinygrad/ast.py index 6658a4a3ea..7e026fd8d8 100644 --- a/tinygrad/ast.py +++ b/tinygrad/ast.py @@ -59,12 +59,9 @@ class ASTKernel: self.ret = output_buffer if output_buffer else type(self.bufs[0])(output_shape if output_shape else self.info.shape, force_create=True) self.bufs = ([type(self.ret)(self.info.shape, hostbuf=self.ret)] if output_shape else [self.ret]) + self.bufs - # TODO: should be optional if it's hitting a function cache - self.processed = False - def process(self) -> None: - if self.processed: return - self.processed = True + if hasattr(self, "sts"): return # already processed + reduceops = [x for x in get_lazyops(self.ast) if x.op in ReduceOps] assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast" self.reduceop = reduceops[0] if reduceops else None