mirror of https://github.com/tinygrad/tinygrad.git
sorry about the line count, this is a good optimization
@@ -6,6 +6,7 @@ import sys, functools, operator
from tinygrad.helpers import ConvArgs
from tinygrad.shapetracker import ShapeTracker

# these are the llops your accelerator must implement
UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN"])
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
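Not part of the diff, but useful orientation: a minimal sketch of a backend implementing a few of these llops in NumPy. Everything below except the Enum definitions (copied from above) is hypothetical, not tinygrad code.

# hypothetical toy backend, not tinygrad code; only the Enums match the diff
from enum import Enum
import numpy as np

UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN"])
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])

class ToyCPUBuffer:
  def __init__(self, data): self.data = np.asarray(data, dtype=np.float32)
  def unary_op(self, op):
    fxn = {UnaryOps.NOOP: lambda x: x, UnaryOps.NEG: np.negative,
           UnaryOps.RELU: lambda x: np.maximum(x, 0), UnaryOps.EXP: np.exp,
           UnaryOps.LOG: np.log, UnaryOps.SIGN: np.sign}[op]
    return ToyCPUBuffer(fxn(self.data))
  def binary_op(self, op, y):
    fxn = {BinaryOps.ADD: np.add, BinaryOps.SUB: np.subtract,
           BinaryOps.MUL: np.multiply, BinaryOps.DIV: np.divide,
           BinaryOps.POW: np.power,
           BinaryOps.CMPEQ: lambda a, b: (a == b).astype(np.float32)}[op]
    return ToyCPUBuffer(fxn(self.data, y.data))

print(ToyCPUBuffer([-1.0, 2.0]).unary_op(UnaryOps.RELU).data)  # [0. 2.]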
@@ -22,8 +23,10 @@ DeviceBuffer = Any
# -O1
MERGE_MOVEMENT_OPS = True
REMOVE_MOVEMENT_NOPS = True
MERGE_UNARY_OPS = True

# -O2
MERGE_ELEMENTWISE_OPS = False # this is making training very slow in the tests
SHUFFLE_MOVEMENT_OPS = False # this is making training very slow in the tests

# -O3
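As a toy illustration of what one of the -O1 passes means (a hypothetical helper, not from this commit): REMOVE_MOVEMENT_NOPS drops movement ops that provably change nothing.

# hypothetical sketch of the REMOVE_MOVEMENT_NOPS idea, not tinygrad code
def is_movement_nop(op: str, arg, shape) -> bool:
  # reshaping to the current shape or permuting with the identity
  # permutation leaves the buffer untouched, so the node can be dropped
  if op == "RESHAPE": return tuple(arg) == tuple(shape)
  if op == "PERMUTE": return tuple(arg) == tuple(range(len(shape)))
  return False

print(is_movement_nop("PERMUTE", (0, 1, 2), (2, 3, 4)))  # True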
@@ -99,6 +102,7 @@ class Device:
       print(op, "not available", e)
   DEFAULT = "CPU" if DEFAULT is None else DEFAULT

+# TODO: make a _realize function for each type called by realize
 def _realize(self:LazyBuffer) -> DeviceBuffer:
   if self.optype == LoadOps and self.op.op == LoadOps.FROMCPU:
     return Device._buffers[self.device].fromCPU(self.op.arg), []
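Note the return convention _realize uses: each branch hands back the computed DeviceBuffer together with the list of realized sources it depends on, empty for a FROMCPU leaf since the data comes straight from host memory. A toy leaf in the same spirit, with hypothetical names:

# toy illustration of the (buffer, realized_sources) convention, not tinygrad code
import numpy as np

def toy_realize_fromcpu(arg):
  # a load leaf has no realized inputs to keep alive, hence the empty list
  return np.array(arg), []

buf, srcs = toy_realize_fromcpu([1.0, 2.0, 3.0])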
@@ -112,9 +116,13 @@ def _realize(self:LazyBuffer) -> DeviceBuffer:
     else:
       return functools.reduce(lambda x,o: x.movement_op(o.op, o.arg), get_lazyops(self.op)[::-1], real_src), [real_src]
   elif self.optype == BinaryOps:
-    real_srcs = [x.realize() for x in self.op.src]
-    if len(real_srcs) == 1: return real_srcs[0].unary_op(self.op.op), real_srcs
-    else: return real_srcs[0].binary_op(self.op.op, real_srcs[1]), real_srcs
+    real_srcs : Dict[LazyBuffer, DeviceBuffer] = {}
+    [real_srcs.setdefault(x,x.realize()) for x in get_lazybuffers(self.op) if x not in real_srcs]
+    def ast_eval(x: Union[LazyBuffer, LazyOp]) -> DeviceBuffer:
+      if isinstance(x, LazyBuffer): return real_srcs[x]
+      if isinstance(x.op, UnaryOps): return ast_eval(x.src[0]).unary_op(x.op)
+      if isinstance(x.op, BinaryOps): return ast_eval(x.src[0]).binary_op(x.op, ast_eval(x.src[1]))
+    return ast_eval(self.op), list(real_srcs.values())
   elif self.optype == ProcessingOps:
     real_src_x = self.op.src[0].realize(self.device)
     real_src_w = self.op.src[1].realize(self.device)
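This hunk is the optimization the commit message apologizes for: instead of realizing exactly two sources and issuing one kernel per op, the BinaryOps branch now realizes every leaf buffer once (the real_srcs dict deduplicates) and recursively evaluates the whole fused AST. A self-contained toy of that evaluator, with plain floats standing in for DeviceBuffers and strings for LazyBuffer leaves:

# toy re-creation of the ast_eval pattern above, not tinygrad code
from collections import namedtuple
ToyOp = namedtuple("ToyOp", ["op", "src"])

def toy_eval(x, real_srcs):
  if isinstance(x, str): return real_srcs[x]   # leaf: already-realized value
  if x.op == "relu": return max(toy_eval(x.src[0], real_srcs), 0.0)
  if x.op == "add": return toy_eval(x.src[0], real_srcs) + toy_eval(x.src[1], real_srcs)
  if x.op == "mul": return toy_eval(x.src[0], real_srcs) * toy_eval(x.src[1], real_srcs)
  raise NotImplementedError(x.op)

# relu(a*b + c) walked as one fused tree instead of three separate kernels
tree = ToyOp("relu", (ToyOp("add", (ToyOp("mul", ("a", "b")), "c")),))
print(toy_eval(tree, {"a": 2.0, "b": -3.0, "c": 1.0}))  # 0.0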
@@ -199,4 +207,10 @@ class LazyBuffer:
     return LazyBuffer(x.device, C.out_shape, ProcessingOps, LazyOp(op, (x, w), C))

 def elementwise_op(op:Union[UnaryOps, BinaryOps], srcs:Tuple[LazyBuffer, ...]) -> LazyBuffer:
-  return LazyBuffer(srcs[0].device, srcs[0].shape, BinaryOps, LazyOp(op, srcs))
+  out_device, out_shape = srcs[0].device, srcs[0].shape
+
+  if (MERGE_UNARY_OPS and len(srcs) == 1) or MERGE_ELEMENTWISE_OPS:
+    # remove the buffers from any BinaryOps that feed into this
+    srcs = tuple(x.op if x.optype == BinaryOps and x.realized is None else x for x in srcs) # type: ignore
+
+  return LazyBuffer(out_device, out_shape, BinaryOps, LazyOp(op, srcs))
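The gated rewrite in elementwise_op is what builds those fused ASTs in the first place: when a source is itself an unrealized BinaryOps result, its LazyOp subtree is spliced in directly, removing the intermediate buffer boundary from the graph. A toy version of just that substitution, with simplified stand-in types, not tinygrad's:

# toy version of the source-merging rewrite above, not tinygrad code
from collections import namedtuple
ToyOp = namedtuple("ToyOp", ["op", "src"])
ToyBuf = namedtuple("ToyBuf", ["name", "op", "realized"])  # op: pending ToyOp or None

def merge_srcs(srcs):
  # splice an unrealized elementwise source's op tree in directly,
  # so later realization sees one AST instead of a buffer boundary
  return tuple(x.op if x.op is not None and x.realized is None else x for x in srcs)

a = ToyBuf("a", None, object())                  # already-realized leaf
b = ToyBuf("b", ToyOp("add", ("x", "y")), None)  # unrealized add result
print(merge_srcs((b, a)))  # (ToyOp(op='add', src=('x', 'y')), ToyBuf(name='a', ...))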