diff --git a/accel/llvm/ops_llvm.py b/accel/llvm/ops_llvm.py
index 0e42f7b843..2b6c0f2a9f 100644
--- a/accel/llvm/ops_llvm.py
+++ b/accel/llvm/ops_llvm.py
@@ -153,8 +153,7 @@ class LLVMBuffer(ExplicitExecAST):
     BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y), BinaryOps.DIV: lambda builder,x,y: builder.fdiv(x,y),
     BinaryOps.POW: lambda builder,x,y: builder.call(builder._block.module.declare_intrinsic('llvm.pow', [ir.FloatType()]), [x,y]),
-    BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y), ir.FloatType()),
-    MovementOps.RESHAPE: lambda builder,x: x,
+    BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y), ir.FloatType())
   }
 
   def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf=None):
     super().__init__(shape, hostbuf)
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 960e4cb691..ef9bd2a051 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -67,28 +67,31 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
   real_srcs : Dict[LazyBuffer, Union[None, LazyOp, DeviceBuffer]] = {x:None for x in get_buffers(self.op)}
   op_type : OpType = BinaryOps
   psrcs : List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype in [ProcessingOps,ReduceOps] and x.realized is None and len(x.children) <= 1 and len(k.children) <= 1]
+  intermediate_shape = self.shape
   if len(psrcs) == 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE and (self.device != "OPENCL" or self.shape[-1] == 4):
     if psrcs[0][1].optype == ProcessingOps:
       real_srcs[psrcs[0][0]] = psrcs[0][1].op
-      real_srcs[psrcs[0][1].op.src[0]], real_srcs[psrcs[0][1].op.src[1]] = None, None
+      for x in psrcs[0][1].op.src:
+        real_srcs[x] = x.realize(self.device)
       op_type = ProcessingOps
     elif psrcs[0][1].optype == ReduceOps:
       src = psrcs[0][1].op.src[0]
       if MERGE_ELEMENTWISE_INTO_REDUCE and src.realized is None and src.optype == BinaryOps and len(src.children) <= 1:
         src = src.op
         for x in (get_buffers(src) if isinstance(src, LazyOp) else [src]):
-          real_srcs[x] = None
+          real_srcs[x] = x.realize(self.device)
       real_srcs[psrcs[0][0]] = LazyOp(psrcs[0][1].op.op, (src,), psrcs[0][1].op.arg)
       op_type = ReduceOps
-    # a reshape is allowed after the ReduceOp, it's a nop in the backend
+    # if the ReduceOp is followed by a reshape, we push this reshape before all the ElementwiseOp inputs
     if psrcs[0][0].shape != psrcs[0][1].shape:
-      real_srcs[psrcs[0][0]] = LazyOp(MovementOps.RESHAPE, (real_srcs[psrcs[0][0]],), psrcs[0][0].shape)
+      intermediate_shape = psrcs[0][1].shape
+      assert psrcs[0][0].shape == self.shape, f"shape mismatch {psrcs[0][0].shape} != {self.shape}"
 
+  # NOTE: these RESHAPEs will return self if they don't change the shape
   for x in real_srcs.keys():
     if real_srcs[x] is None:
-      real_srcs[x] = x.realize(self.device)
+      real_srcs[x] = x.movement_op(MovementOps.RESHAPE, intermediate_shape).realize(self.device)
   ret = self.dbuffer.exec_ast(realize_buffers(real_srcs, self.op))
-  assert ret.shape == self.shape, f"shape mismatch {ret.shape} != {self.shape}"
-  return ret, [x for x in real_srcs.values() if not isinstance(x, LazyOp) and x is not None], op_type
+  return ret.movement_op(MovementOps.RESHAPE, self.shape), [x for x in real_srcs.values() if not isinstance(x, LazyOp) and x is not None], op_type
 
 _realize = {LoadOps:_realize_loadops, ReduceOps:_realize_reduceops, MovementOps:_realize_movementops, BinaryOps:_realize_binaryops, ProcessingOps:_realize_processingops}
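The lazy.py hunk above leans on the fact that RESHAPE is order-preserving on contiguous buffers: instead of reshaping the ReduceOp output before the elementwise op runs, every other elementwise input is reshaped down to the reduce's intermediate shape, the fused kernel runs there, and the result is reshaped back to the output shape at the end. A minimal numpy sketch of that equivalence (illustrative only, not tinygrad code; shapes are arbitrary):

import numpy as np

x = np.random.rand(2, 3, 4).astype(np.float32)  # feeds the ReduceOp
y = np.random.rand(8).astype(np.float32)        # the other elementwise input

# old scheme: reshape the reduce output, then run the elementwise op at (8,)
before = x.sum(axis=1).reshape(8) * y

# new scheme: run the fused reduce+elementwise at the intermediate shape (2, 4),
# reshaping the other input down and the final result back up
intermediate_shape = (2, 4)
after = (x.sum(axis=1) * y.reshape(intermediate_shape)).reshape(8)

assert np.allclose(before, after)  # reshape preserves element order, so they match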
diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py
index ff2bd81ee2..d48def3775 100644
--- a/tinygrad/llops/ops_gpu.py
+++ b/tinygrad/llops/ops_gpu.py
@@ -82,7 +82,7 @@ class GPUBuffer(ExplicitExecAST):
     UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.SIGN: "sign(A)", UnaryOps.RECIPROCAL: "((float)1.0/A)",
     BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
-    ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)", MovementOps.RESHAPE: "(A)"
+    ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)"
   }
   start_for_op = {ReduceOps.SUM: "0.0", ReduceOps.MAX: "-INFINITY"}
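With MovementOps.RESHAPE gone from code_for_op, every remaining entry is a real C expression template with A/B operand placeholders (and acc for reductions). A hypothetical sketch of how such templates expand into kernel source; the render helper and its naive string substitution are illustrative assumptions, not the actual GPUBuffer codegen:

# illustrative only: a toy expansion of A/B placeholder templates like the ones
# in code_for_op above; the real GPUBuffer codegen is more involved
code_for_op = {
  "ADD": "(A+B)", "MUL": "(A*B)", "CMPEQ": "(A==B)",
  "SUM": "(acc + A)", "MAX": "max(A, acc)",
}

def render(op: str, a: str, b: str = "") -> str:
  # substitute the operand expressions for the A/B placeholders
  return code_for_op[op].replace("A", a).replace("B", b)

print(render("ADD", "x[gid]", "y[gid]"))  # -> (x[gid]+y[gid])
print(render("SUM", "x[gid]"))            # -> (acc + x[gid])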