sorry about the line count, this is a good optimization

This commit is contained in:
George Hotz
2022-07-03 17:11:13 -07:00
parent 748618530b
commit 6b0aa2a902

View File

@@ -6,6 +6,7 @@ import sys, functools, operator
from tinygrad.helpers import ConvArgs
from tinygrad.shapetracker import ShapeTracker
# these are the llops your accelerator must implement
# UnaryOps: elementwise ops evaluated on a single source via the buffer's unary_op(op)
UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN"])
# BinaryOps: elementwise ops evaluated on two sources via the buffer's binary_op(op, other)
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
# ReduceOps: reduction kinds; NOTE(review): their dispatch is not visible in this excerpt
ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
@@ -22,8 +23,10 @@ DeviceBuffer = Any
# Module-level optimization toggles, grouped by optimization level.
# -O1
# NOTE(review): the code reading the two movement-op flags below is not visible in this excerpt
MERGE_MOVEMENT_OPS = True
REMOVE_MOVEMENT_NOPS = True
# when set, elementwise_op with exactly one source replaces an unrealized
# BinaryOps-typed source buffer with that buffer's op, building one nested op tree
MERGE_UNARY_OPS = True
# -O2
# when set, elementwise_op applies the same source replacement for any number of sources
MERGE_ELEMENTWISE_OPS = False # this is making training very slow in the tests
# NOTE(review): the code reading SHUFFLE_MOVEMENT_OPS is not visible in this excerpt
SHUFFLE_MOVEMENT_OPS = False # this is making training very slow in the tests
# -O3
@@ -99,6 +102,7 @@ class Device:
print(op, "not available", e)
DEFAULT = "CPU" if DEFAULT is None else DEFAULT
# TODO: make a _realize function for each type called by realize
def _realize(self:LazyBuffer) -> DeviceBuffer:
if self.optype == LoadOps and self.op.op == LoadOps.FROMCPU:
return Device._buffers[self.device].fromCPU(self.op.arg), []
@@ -112,9 +116,13 @@ def _realize(self:LazyBuffer) -> DeviceBuffer:
else:
return functools.reduce(lambda x,o: x.movement_op(o.op, o.arg), get_lazyops(self.op)[::-1], real_src), [real_src]
elif self.optype == BinaryOps:
real_srcs = [x.realize() for x in self.op.src]
if len(real_srcs) == 1: return real_srcs[0].unary_op(self.op.op), real_srcs
else: return real_srcs[0].binary_op(self.op.op, real_srcs[1]), real_srcs
real_srcs : Dict[LazyBuffer, DeviceBuffer] = {}
[real_srcs.setdefault(x,x.realize()) for x in get_lazybuffers(self.op) if x not in real_srcs]
def ast_eval(x: Union[LazyBuffer, LazyOp]) -> DeviceBuffer:
if isinstance(x, LazyBuffer): return real_srcs[x]
if isinstance(x.op, UnaryOps): return ast_eval(x.src[0]).unary_op(x.op)
if isinstance(x.op, BinaryOps): return ast_eval(x.src[0]).binary_op(x.op, ast_eval(x.src[1]))
return ast_eval(self.op), list(real_srcs.values())
elif self.optype == ProcessingOps:
real_src_x = self.op.src[0].realize(self.device)
real_src_w = self.op.src[1].realize(self.device)
@@ -199,4 +207,10 @@ class LazyBuffer:
return LazyBuffer(x.device, C.out_shape, ProcessingOps, LazyOp(op, (x, w), C))
def elementwise_op(op:Union[UnaryOps, BinaryOps], srcs:Tuple[LazyBuffer, ...]) -> LazyBuffer:
  """Create the lazy buffer for an elementwise `op` applied to `srcs`.

  The result takes its device and shape from the first source and is always
  tagged with the BinaryOps optype (single-source ops included).

  When merging is enabled (MERGE_UNARY_OPS for a single source, or
  MERGE_ELEMENTWISE_OPS for any number of sources), every source that is
  itself a BinaryOps-typed buffer and has not been realized yet is replaced
  by its op, so the new LazyOp nests the source's op tree instead of
  referencing an intermediate buffer.

  Args:
    op: the unary or binary op to apply.
    srcs: the source buffers; must contain at least one element.

  Returns:
    A new, unrealized LazyBuffer wrapping LazyOp(op, srcs).
  """
  # capture these before srcs may be rewritten to hold ops instead of buffers
  out_device, out_shape = srcs[0].device, srcs[0].shape
  if (MERGE_UNARY_OPS and len(srcs) == 1) or MERGE_ELEMENTWISE_OPS:
    # remove the buffers from any BinaryOps that feed into this
    srcs = tuple(x.op if x.optype == BinaryOps and x.realized is None else x for x in srcs)  # type: ignore
  return LazyBuffer(out_device, out_shape, BinaryOps, LazyOp(op, srcs))