diff --git a/docs/DESIGNv2.md b/docs/DESIGNv2.md new file mode 100644 index 0000000000..5dda08e36b --- /dev/null +++ b/docs/DESIGNv2.md @@ -0,0 +1,17 @@ +tinygrad is a bit bloated now, and there are several places where concerns should be separated and they aren't. + +tensor.py and mlops.py are great code. The interface going backward here is: + +LazyBuffer.const (this creates a matching size buffer) +LazyBuffer.contiguous (this is not exactly elementwise) +LazyBuffer.e (elementwise) +LazyBuffer.r (reduce) +reshape/permute/expand/stride/shrink/pad (movement) + +The lazy.py reordering engine has a lot of junk to deal with movementops that should be removed. + +view.py is mostly great code, except it shouldn't have the rendering logic, and the int type should be parameterized to not import from symbolic. + +LazyOp shouldn't have LazyBuffers as sources, just LazyOp LoadOps with a tuple of Views. Then the LazyOp uniquely determines the kernel and we don't have to do any replacement. + +ShapeTracker probably shouldn't exist and just be a part of LazyBuffer. Most of the stuff in ShapeTracker should move to symbolic_view, which combines view and symbolic. diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 80c61f8e42..f863855505 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -243,7 +243,7 @@ class LazyBuffer: srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,) return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype, self.var_vals) - def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer: + def r(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer: if any(not isinstance(s, int) for s in self.shape) or prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach. 
heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new) # type: ignore if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, new_shape) # Choose largest divisor (>=16) to split on, penalize large strides. @@ -289,7 +289,7 @@ class LazyBuffer: src, rop = self.op.src[0], self.op.op src.children.discard(self) del self # TODO: why doesn't this delete remove it from the children - return src.permute(arg).reduce_op(cast(ReduceOps, rop), narg) + return src.permute(arg).r(cast(ReduceOps, rop), narg) # move permutes before expands (always, this is safe) if self.op.op == MovementOps.EXPAND: diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index fd0f1c7db7..0bbb911601 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -89,20 +89,20 @@ class Sigmoid(Function): class Sum(Function): def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer: self.input_shape = x.shape - return x.reduce_op(ReduceOps.SUM, new_shape) + return x.r(ReduceOps.SUM, new_shape) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape) class Max(Function): def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer: - self.x, self.ret = x, x.reduce_op(ReduceOps.MAX, new_shape) + self.x, self.ret = x, x.r(ReduceOps.MAX, new_shape) return self.ret def backward(self, grad_output:LazyBuffer) -> LazyBuffer: # 1s in locations where the max was chosen (can be two locations) max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape))) - div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand(self.x.shape) + div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape) return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape)) # ************* binary ops ************* @@ -166,7 
+166,7 @@ class Expand(Function): return x.expand(shape) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.reduce_op(ReduceOps.SUM, self.input_shape) + return grad_output.r(ReduceOps.SUM, self.input_shape) class Reshape(Function): def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer: