diff --git a/tinygrad/ops.py b/tinygrad/ops.py index e4c732abab..cdee1dd993 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -6,6 +6,7 @@ import sys, functools, operator from tinygrad.helpers import ConvArgs from tinygrad.shapetracker import ShapeTracker +# these are the llops your accelerator must implement UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN"]) BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"]) ReduceOps = Enum("ReduceOps", ["SUM", "MAX"]) @@ -22,8 +23,10 @@ DeviceBuffer = Any # -O1 MERGE_MOVEMENT_OPS = True REMOVE_MOVEMENT_NOPS = True +MERGE_UNARY_OPS = True # -O2 +MERGE_ELEMENTWISE_OPS = False # this is making training very slow in the tests SHUFFLE_MOVEMENT_OPS = False # this is making training very slow in the tests # -O3 @@ -99,6 +102,7 @@ class Device: print(op, "not available", e) DEFAULT = "CPU" if DEFAULT is None else DEFAULT +# TODO: make a _realize function for each type called by realize def _realize(self:LazyBuffer) -> DeviceBuffer: if self.optype == LoadOps and self.op.op == LoadOps.FROMCPU: return Device._buffers[self.device].fromCPU(self.op.arg), [] @@ -112,9 +116,13 @@ def _realize(self:LazyBuffer) -> DeviceBuffer: else: return functools.reduce(lambda x,o: x.movement_op(o.op, o.arg), get_lazyops(self.op)[::-1], real_src), [real_src] elif self.optype == BinaryOps: - real_srcs = [x.realize() for x in self.op.src] - if len(real_srcs) == 1: return real_srcs[0].unary_op(self.op.op), real_srcs - else: return real_srcs[0].binary_op(self.op.op, real_srcs[1]), real_srcs + real_srcs : Dict[LazyBuffer, DeviceBuffer] = {} + [real_srcs.setdefault(x,x.realize()) for x in get_lazybuffers(self.op) if x not in real_srcs] + def ast_eval(x: Union[LazyBuffer, LazyOp]) -> DeviceBuffer: + if isinstance(x, LazyBuffer): return real_srcs[x] + if isinstance(x.op, UnaryOps): return ast_eval(x.src[0]).unary_op(x.op) + if isinstance(x.op, BinaryOps): return ast_eval(x.src[0]).binary_op(x.op, ast_eval(x.src[1])) + return ast_eval(self.op), list(real_srcs.values()) elif self.optype == ProcessingOps: real_src_x = self.op.src[0].realize(self.device) real_src_w = self.op.src[1].realize(self.device) @@ -199,4 +207,10 @@ class LazyBuffer: return LazyBuffer(x.device, C.out_shape, ProcessingOps, LazyOp(op, (x, w), C)) def elementwise_op(op:Union[UnaryOps, BinaryOps], srcs:Tuple[LazyBuffer, ...]) -> LazyBuffer: - return LazyBuffer(srcs[0].device, srcs[0].shape, BinaryOps, LazyOp(op, srcs)) + out_device, out_shape = srcs[0].device, srcs[0].shape + + if (MERGE_UNARY_OPS and len(srcs) == 1) or MERGE_ELEMENTWISE_OPS: + # remove the buffers from any BinaryOps that feed into this + srcs = tuple(x.op if x.optype == BinaryOps and x.realized is None else x for x in srcs) # type: ignore + + return LazyBuffer(out_device, out_shape, BinaryOps, LazyOp(op, srcs))