diff --git a/test/null/test_tensor_uop_mixin.py b/test/null/test_tensor_uop_mixin.py
index 2f4a9e0b59..2e5ce01ea5 100644
--- a/test/null/test_tensor_uop_mixin.py
+++ b/test/null/test_tensor_uop_mixin.py
@@ -1,61 +1,62 @@
 import math, unittest
 from tinygrad import Tensor

-# TODO: make all the expectedFailure cases pass — i.e. UOp.__getitem__ should produce the same UOp graph as
-# Tensor.__getitem__ for every view-returning index pattern.
-
 def _t(*shape): return Tensor.arange(math.prod(shape)).reshape(*shape)

+# Tensor().func().uop should be the same as UOp.func()
+def _check(tc: unittest.TestCase, t: Tensor, fn):
+  tc.assertIs(fn(t).uop, fn(t.uop), f"\ntensor.uop = {fn(t).uop}\nuop = {fn(t.uop)}")
+
 class TestTensorUOpGetitem(unittest.TestCase):
-  """For each pattern, check that `Tensor(x)[idx].uop` equals `x.uop[idx]`."""
-
-  def _check(self, t: Tensor, idx):
-    via_tensor = t[idx].uop
-    via_uop = t.uop[idx]
-    self.assertIs(via_tensor, via_uop, f"\nidx={idx!r}\ntensor.uop = {via_tensor}\nuop[idx] = {via_uop}")
-
   # ---- pure slice patterns ----
-  def test_slice_full(self): self._check(_t(4), slice(None))
-  def test_slice_positive(self): self._check(_t(8), slice(1, 5))
-  def test_slice_open_start(self): self._check(_t(8), slice(None, 5))
-  def test_slice_open_stop(self): self._check(_t(8), slice(3, None))
-  def test_slice_negative_start(self): self._check(_t(8), slice(-3, None))
-  def test_slice_negative_stop(self): self._check(_t(8), slice(None, -2))
-  def test_slice_both_negative(self): self._check(_t(8), slice(-5, -1))
+  def test_slice_full(self): _check(self, _t(4), lambda x: x[slice(None)])
+  def test_slice_positive(self): _check(self, _t(8), lambda x: x[1:5])
+  def test_slice_open_start(self): _check(self, _t(8), lambda x: x[:5])
+  def test_slice_open_stop(self): _check(self, _t(8), lambda x: x[3:])
+  def test_slice_negative_start(self): _check(self, _t(8), lambda x: x[-3:])
+  def test_slice_negative_stop(self): _check(self, _t(8), lambda x: x[:-2])
+  def test_slice_both_negative(self): _check(self, _t(8), lambda x: x[-5:-1])

   # ---- slice with stride ----
-  def test_slice_stride(self): self._check(_t(6), slice(None, None, 2))
-  def test_slice_start_stop_stride(self): self._check(_t(6), slice(1, 5, 2))
-  def test_slice_reverse(self): self._check(_t(6), slice(None, None, -1))
-  def test_slice_singleton_negative_step(self): self._check(_t(8), slice(3, 2, -1))
+  def test_slice_stride(self): _check(self, _t(6), lambda x: x[::2])
+  def test_slice_start_stop_stride(self): _check(self, _t(6), lambda x: x[1:5:2])
+  def test_slice_reverse(self): _check(self, _t(6), lambda x: x[::-1])
+  def test_slice_singleton_negative_step(self): _check(self, _t(8), lambda x: x[3:2:-1])

   # ---- empty / out-of-bounds slice ----
-  def test_slice_empty(self): self._check(_t(6), slice(3, 1))
-  def test_slice_oob_stop(self): self._check(_t(6), slice(0, 100))
+  def test_slice_empty(self): _check(self, _t(6), lambda x: x[3:1])
+  def test_slice_oob_stop(self): _check(self, _t(6), lambda x: x[0:100])

   # ---- single int (reduces a dim) ----
-  def test_int_positive(self): self._check(_t(8), 3)
-  def test_int_negative(self): self._check(_t(8), -1)
+  def test_int_positive(self): _check(self, _t(8), lambda x: x[3])
+  def test_int_negative(self): _check(self, _t(8), lambda x: x[-1])

   # ---- ellipsis ----
-  def test_ellipsis_only(self): self._check(_t(2, 3, 4), (Ellipsis,))
-  def test_ellipsis_then_int(self): self._check(_t(2, 3, 4), (Ellipsis, -1))
-  def test_ellipsis_then_slice(self): self._check(_t(2, 3, 4), (Ellipsis, slice(1, 3)))
-  def test_ellipsis_then_none(self): self._check(_t(2, 3), (Ellipsis, None))
+  def test_ellipsis_only(self): _check(self, _t(2, 3, 4), lambda x: x[...])
+  def test_ellipsis_then_int(self): _check(self, _t(2, 3, 4), lambda x: x[..., -1])
+  def test_ellipsis_then_slice(self): _check(self, _t(2, 3, 4), lambda x: x[..., 1:3])
+  def test_ellipsis_then_none(self): _check(self, _t(2, 3), lambda x: x[..., None])

   # ---- None (unsqueeze) ----
-  def test_none_front(self): self._check(_t(4), (None,))
-  def test_none_back(self): self._check(_t(4), (slice(None), None))
-  def test_none_middle(self): self._check(_t(2, 3), (slice(None), None, slice(None)))
-  def test_multiple_none(self): self._check(_t(2, 3), (None, slice(None), None))
+  def test_none_front(self): _check(self, _t(4), lambda x: x[None])
+  def test_none_back(self): _check(self, _t(4), lambda x: x[:, None])
+  def test_none_middle(self): _check(self, _t(2, 3), lambda x: x[:, None, :])
+  def test_multiple_none(self): _check(self, _t(2, 3), lambda x: x[None, :, None])

   # ---- mixed multi-dim ----
-  def test_int_then_slice(self): self._check(_t(2, 3), (1, slice(None)))
-  def test_multi_int(self): self._check(_t(2, 3, 4), (1, 2))
-  def test_mixed_slice_int(self): self._check(_t(2, 3, 4), (slice(0, 2), -1, slice(1, 3)))
-  def test_mixed_slice_slice(self): self._check(_t(3, 4, 5), (slice(1, 3), slice(None), slice(0, 2)))
-  def test_high_rank_combo(self): self._check(_t(4, 5, 6), (slice(1, 3), slice(None), -1, None))
+  def test_int_then_slice(self): _check(self, _t(2, 3), lambda x: x[1, :])
+  def test_multi_int(self): _check(self, _t(2, 3, 4), lambda x: x[1, 2])
+  def test_mixed_slice_int(self): _check(self, _t(2, 3, 4), lambda x: x[0:2, -1, 1:3])
+  def test_mixed_slice_slice(self): _check(self, _t(3, 4, 5), lambda x: x[1:3, :, 0:2])
+  def test_high_rank_combo(self): _check(self, _t(4, 5, 6), lambda x: x[1:3, :, -1, None])
+
+class TestTensorUOpCumalu(unittest.TestCase):
+  def test_cumsum_1d(self): _check(self, _t(5), lambda x: x.cumsum())
+  def test_cumsum_2d(self): _check(self, _t(3, 4), lambda x: x.cumsum(1))
+  def test_cumsum_non_last(self): _check(self, _t(3, 4), lambda x: x.cumsum(0))
+  def test_cumsum_large(self): _check(self, _t(600), lambda x: x.cumsum())  # exercises _split_cumalu
+  def test_cumprod(self): _check(self, _t(4), lambda x: x.cumprod(0))

 if __name__ == "__main__":
   unittest.main()
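A note on the harness: `_check` compares with `assertIs`, i.e. Python object identity. That works because tinygrad interns UOps, so structurally identical graphs are the same object and identity doubles as graph equality. A minimal sketch of the property being relied on, outside the diff (assumes current tinygrad interning behavior):

```python
from tinygrad import Tensor

t = Tensor.arange(8)
# the same view built twice hash-conses to the same UOp object,
# so `is` is a valid stand-in for structural graph comparison
assert t[1:5].uop is t[1:5].uop
```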
diff --git a/tinygrad/mixin/__init__.py b/tinygrad/mixin/__init__.py
index 38b8228042..efd645e4af 100644
--- a/tinygrad/mixin/__init__.py
+++ b/tinygrad/mixin/__init__.py
@@ -1,15 +1,26 @@
 import functools
 from typing import Self, Sequence, Literal, get_args
 from tinygrad.mixin.elementwise import ElementwiseMixin
+from tinygrad.mixin.movement import MovementMixin
 from tinygrad.mixin.reduce import ReduceMixin
-from tinygrad.uop.ops import _broadcast_shape, resolve
+from tinygrad.uop import Ops
+from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element
 from tinygrad.dtype import DTypeLike, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
-from tinygrad.helpers import argfix, prod
+from tinygrad.helpers import argfix, flatten, prod, round_up

 ReductionStr = Literal["mean", "sum", "none"]

 class OpMixin(ElementwiseMixin, ReduceMixin):
+  def _pad_constant(self, pX, value:float) -> Self:
+    # shrink first for negative pads, then pad with only non-negative values
+    pX = tuple((0, 0) if p is None else p for p in pX)
+    has_neg = not all(resolve(p >= 0) for p in flatten(pX))
+    X = self.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, self.shape))) if has_neg else self
+    pads = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX) if has_neg else pX
+    if value == 0: return MovementMixin.pad(X, pads)
+    return MovementMixin.pad(X, pads) + MovementMixin.pad(X.ones_like(), pads).cast(dtypes.bool).where(0, value)
+
   def _broadcasted(self, y, reverse=False) -> tuple[Self, Self]:
     if not isinstance(y, type(self)): y = self.ufix(y)
     x, y = (self, y) if not reverse else (y, self)
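The shrink-then-pad trick in `_pad_constant` is easiest to see with concrete numbers: a negative pad amount crops rather than pads, so the negative components are applied first as a `shrink` window and only the non-negative remainder goes to `pad`. A pure-Python sketch of that bookkeeping (hypothetical helper, not tinygrad code):

```python
# split per-axis pads (pB, pA) into a shrink window plus non-negative pads,
# mirroring the smin/smax arithmetic above with plain ints
def split_pads(shape, pX):
  shrink = tuple((max(-pB, 0), min(pA + s, s)) for (pB, pA), s in zip(pX, shape))
  pads = tuple((max(pB, 0), max(pA, 0)) for pB, pA in pX)
  return shrink, pads

# pads of (-1, 2) on an axis of length 5: crop 1 at the front, pad 2 at the
# back, final length 5 - 1 + 2 = 6
assert split_pads((5,), ((-1, 2),)) == (((1, 5),), ((0, 2),))
```

For a non-zero `value`, the method pads with zero and then adds `value` only in the padded region, using a padded ones-mask to tell inside from outside; that is what the second `MovementMixin.pad` call is for.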
@@ -249,6 +260,54 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
     m = self.max(axis=axis, keepdim=True)
     return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + (m if keepdim else m.squeeze(axis))

+  def _cumalu(self, axis:int, op:Ops) -> Self:
+    assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX, Ops.MUL)
+    pads = (None,)*(self.ndim-1) + ((self.shape[axis]-1, 0),)
+    pooled = self.transpose(axis,-1)._pad_constant(pads, identity_element(op, self.dtype))._pool((self.shape[axis],))
+    return getattr(pooled, {Ops.ADD: "sum", Ops.MAX: "max", Ops.MUL: "prod"}[op])(-1).transpose(axis, -1)
+
+  def _split_cumalu(self, axis:int, op:Ops) -> Self:
+    axis = self._resolve_dim(axis)
+    if self.ndim == 0 or 0 in self.shape: return self
+    # TODO: someday the optimizer will find this on its own
+    # for now this is a two stage cumsum
+    SPLIT = 256
+    value = identity_element(op, self.dtype)
+    if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumalu(axis, op)
+    ret = self.transpose(axis,-1)._pad_constant((None,)*(self.ndim-1)+((round_up(s,SPLIT)-s,0),), value).unflatten(-1,(-1,SPLIT))._cumalu(-1, op)
+    base = ret[..., -1]._cumalu(-1, op)._pad_constant((None,)*(ret.ndim-2) + ((1, -1),), value)
+    base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
+    def fix(x: Self) -> Self: return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
+    return getattr(fix(ret), {Ops.ADD: "add", Ops.MAX: "maximum", Ops.MUL: "mul"}[op])(fix(base))
+
+  def cumsum(self, axis:int=0) -> Self:
+    """
+    Computes the cumulative sum of the tensor along the specified `axis`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.ones(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.cumsum(1).numpy())
+    ```
+    """
+    return self._split_cumalu(axis, Ops.ADD)
+
+  def cumprod(self, axis:int) -> Self:
+    """
+    Computes the cumulative product of the elements of the tensor along the specified `axis`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(1, 7).reshape(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.cumprod(axis=0).numpy())
+    ```
+    """
+    return self._split_cumalu(axis, Ops.MUL)
+
   # ***** functional nn ops *****

   def linear(self, weight:Self, bias:Self|None=None, dtype:DTypeLike|None=None) -> Self:
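The "two stage cumsum" in `_split_cumalu` is the standard chunked-scan decomposition: scan each SPLIT-sized chunk independently, then combine with an exclusive scan of the per-chunk totals. A list-based sketch of the same idea for `Ops.ADD` (illustrative only, with a tiny `split`; the real code does this with padded tensor views):

```python
def split_cumsum(xs, split=4):
  xs = list(xs)
  # stage 1: left-pad with the identity element (0 for ADD) to a multiple of
  # `split`, then cumsum inside each chunk independently
  pad = -len(xs) % split
  padded = [0] * pad + xs
  chunks = [padded[i:i+split] for i in range(0, len(padded), split)]
  scans = [[sum(c[:i+1]) for i in range(split)] for c in chunks]
  # stage 2: an exclusive cumsum of the chunk totals is each chunk's offset
  offsets = [sum(s[-1] for s in scans[:i]) for i in range(len(scans))]
  out = [v + offsets[i] for i, s in enumerate(scans) for v in s]
  return out[pad:]  # drop the front padding, like fix() above

assert split_cumsum(range(10)) == [0, 1, 3, 6, 10, 15, 21, 28, 36, 45]
```

The split matters because `_cumalu` materializes one window per output element, which is O(n²) work along the scanned axis; capping chunks at SPLIT keeps both stages small.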
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(1, 7).reshape(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.cumprod(axis=0).numpy())
+    ```
+    """
+    return self._split_cumalu(axis, Ops.MUL)
+
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 5d7abd535c..8961e98bd7 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -11,7 +11,7 @@ from tinygrad.helpers import IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, ceildiv,
 from tinygrad.helpers import suppress_finalizing, disable_gc
 from tinygrad.gradient import compute_gradient
 from tinygrad.mixin import OpMixin, ReductionStr
-from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, Variable
+from tinygrad.uop.ops import smax, UOp, Ops, sint, all_metadata, _index_to_concrete_int, sint_to_uop, Variable
 from tinygrad.uop.ops import _broadcast_shape
 from tinygrad.engine.schedule import ExecItem, complete_create_schedule_with_vars
 from tinygrad.device import Buffer, canonicalize_device
@@ -1054,14 +1054,6 @@ class Tensor(OpMixin):
   def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, extra_args=(op,), arg=arg)
   def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor: return self._apply_uop(UOp._rop, op=op, axis=axis)

-  def _pad_constant(self, pX:tuple[tuple[sint, sint], ...], value:float) -> Tensor:
-    # shrink first for negative pads, then pad with only non-negative values
-    has_neg = not all(resolve(p >= 0) for p in flatten(pX))
-    X = self.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, self.shape))) if has_neg else self
-    pads = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX) if has_neg else pX
-    if value == 0: return X._apply_uop(UOp.pad, arg=pads)
-    return X._apply_uop(UOp.pad, arg=pads) + Tensor.ones_like(X)._apply_uop(UOp.pad, arg=pads).where(0, value)
-
   def _pad_circular(self, pX:tuple[tuple[sint, sint], ...]) -> Tensor:
     if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, self.shape)): raise ValueError('Padding value causes wrapping around more than once.')
     if any(pB<0 or pA<0 for pB,pA in pX): raise NotImplementedError("Negative pads with circular pads is not supported")
@@ -1921,53 +1913,6 @@ class Tensor(OpMixin):
     if IMAGE: return self.image_dot(w, dtype)
     return super().dot(w, dtype=dtype)

-  def _cumalu(self, axis:int, op:Ops) -> Tensor:
-    assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX, Ops.MUL)
-    pooled = self.transpose(axis,-1).pad((self.shape[axis]-1, 0), value=identity_element(op, self.dtype))._pool((self.shape[axis],))
-    return getattr(pooled, {Ops.ADD: "sum", Ops.MAX: "max", Ops.MUL: "prod"}[op])(-1).transpose(axis, -1)
-
-  def _split_cumalu(self, axis:int, op:Ops) -> Tensor:
-    axis = self._resolve_dim(axis)
-    if self.ndim == 0 or 0 in self.shape: return self
-    # TODO: someday the optimizer will find this on its own
-    # for now this is a two stage cumsum
-    SPLIT = 256
-    value = identity_element(op, self.dtype)
-    if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumalu(axis, op)
-    ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0), value=value).unflatten(-1, (-1, SPLIT))._cumalu(-1, op)
-    base = ret[..., -1]._cumalu(-1, op).pad((1, -1), value=value)
-    base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
-    def fix(x: Tensor) -> Tensor: return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
-    return getattr(fix(ret), {Ops.ADD: "add", Ops.MAX: "maximum", Ops.MUL: "mul"}[op])(fix(base))
-
-  def cumsum(self, axis:int=0) -> Tensor:
-    """
-    Computes the cumulative sum of the tensor along the specified `axis`.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.ones(2, 3)
-    print(t.numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.cumsum(1).numpy())
-    ```
-    """
-    return self._split_cumalu(axis, Ops.ADD)
-
-  def cumprod(self, axis:int) -> Tensor:
-    """
-    Computes the cumulative product of the elements of the tensor along the specified `axis`.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.arange(1, 7).reshape(2, 3)
-    print(t.numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.cumprod(axis=0).numpy())
-    ```
-    """
-    return self._split_cumalu(axis, Ops.MUL)
-
   def cummax(self, axis:int=0) -> tuple[Tensor, Tensor]:
     """
     Computes the cumulative max of the tensor along `axis`, returning (values, indices).