only childless elementwise ops get merged

This commit is contained in:
George Hotz
2022-07-07 11:13:25 -07:00
parent 04e7e4104c
commit eb6696c3a5

View File

@@ -254,7 +254,7 @@ def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer) -> LazyBuffe
out_device, out_shape = srcs[0].device, srcs[0].shape
if (MERGE_UNARY_OPS and len(srcs) == 1) or MERGE_ELEMENTWISE_OPS:
# remove the buffers from any BinaryOps that feed into this
srcs = tuple(x.op if x.optype == BinaryOps and x.realized is None else x for x in srcs) # type: ignore
# remove the buffers from any (childless) BinaryOps that feed into this
srcs = tuple(x.op if x.optype == BinaryOps and len(x.children) == 0 and x.realized is None else x for x in srcs) # type: ignore
return LazyBuffer(out_device, out_shape, BinaryOps, LazyOp(op, srcs))