use ast engine for merged reduceop

This commit is contained in:
George Hotz
2022-07-08 09:37:40 -07:00
parent 3656a5615a
commit 8557ed88df

View File

@@ -111,8 +111,7 @@ def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
buf_names : Dict[LazyBuffer, str] = {x:f"arg_{i}" for i,x in enumerate(real_srcs.keys())}
return self.dbuffer(self.shape)._processing_op([(buf_names[lb], db) for lb,db in real_srcs.items()], \
self.dbuffer.code_for_op[self.op.op].replace("A", _ast(src.op, buf_names, self.dbuffer.code_for_op)), \
start=self.dbuffer.start_for_op[self.op.op]), \
_ast(LazyOp(self.op.op, [src.op], self.op.arg), buf_names, self.dbuffer.code_for_op), start=self.dbuffer.start_for_op[self.op.op]), \
list(real_srcs.values()), ReduceOps
else:
real_src = src.realize(self.device)
@@ -136,6 +135,7 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
conv_args : Optional[ConvArgs] = None
psrcs = [x for x in real_srcs.keys() if x.optype == ProcessingOps and x.realized is None and len(x.children) <= 1]
if len(psrcs) == 1 and MERGE_ONE_CONV_INTO_ELEMENTWISE:
# TODO: do something similar to what i did with reduceop to use the ast engine?
conv_args = psrcs[0].op.arg
del real_srcs[psrcs[0]]
real_srcs[psrcs[0].op.src[0]], real_srcs[psrcs[0].op.src[1]] = None, None