diff --git a/docs/tensor/ops.md b/docs/tensor/ops.md index d613aad00e..e17b9873a1 100644 --- a/docs/tensor/ops.md +++ b/docs/tensor/ops.md @@ -27,6 +27,7 @@ ::: tinygrad.Tensor.matmul ::: tinygrad.Tensor.einsum ::: tinygrad.Tensor.cumsum +::: tinygrad.Tensor.cummax ::: tinygrad.Tensor.triu ::: tinygrad.Tensor.tril ::: tinygrad.Tensor.interpolate diff --git a/extra/models/llama.py b/extra/models/llama.py index 7777a13dd5..772f052095 100644 --- a/extra/models/llama.py +++ b/extra/models/llama.py @@ -135,7 +135,7 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float): # approximate top p # because we are already limited to top k elements we can do top p "without sorting" - output_cumsum = output[::-1]._cumsum()[::-1] + t.sum() + output_cumsum = output[::-1].cumsum()[::-1] + t.sum() output = (output_cumsum >= (1 - p)) * output output_indices = (output_cumsum >= (1 - p)) * output_indices diff --git a/test/test_arange.py b/test/test_arange.py index 822000661d..852112d00e 100644 --- a/test/test_arange.py +++ b/test/test_arange.py @@ -6,6 +6,7 @@ from tinygrad.engine.realize import run_schedule from tinygrad.codegen.kernel import Opt, OptOps, Kernel, KernelOptError from tinygrad.engine.realize import CompiledRunner, ExecItem from tinygrad.engine.search import get_kernel_actions +from tinygrad.ops import Ops class TestArange(unittest.TestCase): def _get_flops(self, N, opts=None): @@ -86,7 +87,7 @@ class TestIndexing(unittest.TestCase): print("*** indexing ***") with Context(NOOPT=1, FUSE_ARANGE=1): GlobalCounters.reset() - rng = Tensor.ones(4, 256, 16384, dtype=dtypes.int)._cumsum(axis=-1, _first_zero=True).reshape(4, 256, 16384, 1) + rng = Tensor.ones(4, 256, 16384, dtype=dtypes.int)._cumalu(axis=-1, op=Ops.ADD, _include_initial=True).reshape(4, 256, 16384, 1) idxs = idxs.reshape(4,1,1,1).expand(4, 256, 16384, 1) reshape_dataset = dataset.T.reshape(1, 256, 16384, 1).expand(4, 256, 16384, 1) full = (rng==idxs).where(reshape_dataset, Tensor.zeros(4, 256, 16384, 1)) diff --git a/test/test_ops.py b/test/test_ops.py index 5135ec5455..be614ae157 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -801,6 +801,27 @@ class TestOps(unittest.TestCase): helper_test_op([(0,3)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0)) helper_test_op([(2,3,0)], lambda x: torch.cumsum(x, dim=2), lambda x: Tensor.cumsum(x, axis=2)) + def test_small_cummax(self): + helper_test_op([(10)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)) + def test_simple_cummax(self): + helper_test_op([(512)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)) + helper_test_op([(1022)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)) + def test_cummax(self): + helper_test_op([()], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)) + # TODO: torch allows this? + # self.helper_test_exception([()], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1), expected=IndexError) + helper_test_op([(20,)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)) + self.helper_test_exception([(20,)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1), expected=IndexError) + self.helper_test_exception([(20,)], lambda x: torch.cummax(x, dim=-2).values, lambda x: Tensor.cummax(x, axis=-2), expected=IndexError) + helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)) + helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1)) + helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=2).values, lambda x: Tensor.cummax(x, axis=2)) + helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=-1).values, lambda x: Tensor.cummax(x, axis=-1)) + def test_cummax_zero_axis(self): + helper_test_op([(2,0,4)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1)) + helper_test_op([(0,3)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)) + helper_test_op([(2,3,0)], lambda x: torch.cummax(x, dim=2).values, lambda x: Tensor.cummax(x, axis=2)) + def test_argmax(self): # check if it returns the first index for multiple occurences self.assertEqual(torch.tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy()) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index fb61289615..5e06220848 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -183,7 +183,7 @@ class GroupOp: UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV} # https://en.wikipedia.org/wiki/Identity_element -def identity_element(op:Ops, dt:DType): return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt) +def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt) def can_pad(u:UOp) -> bool: return not any(x.op in GroupOp.UnsafePad for x in u.sparents) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index ec8b62d386..7475361a31 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -9,7 +9,7 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, leas from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN from tinygrad.multi import MultiLazyBuffer -from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait +from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element from tinygrad.device import Device, Buffer, BufferSpec from tinygrad.engine.lazy import LazyBuffer from tinygrad.engine.realize import run_schedule @@ -607,7 +607,7 @@ class Tensor(SimpleMathTrait): dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int) # NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs) - return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype) + return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype) @staticmethod def linspace(start:Union[int, float], stop:Union[int, float], steps:int, **kwargs) -> Tensor: @@ -2191,15 +2191,28 @@ class Tensor(SimpleMathTrait): """ return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype) - def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor: - assert self.shape[axis] != 0 - pl_sz = self.shape[axis] - int(not _first_zero) - return self.transpose(axis,-1).pad((pl_sz,-int(_first_zero)))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1) + def _cumalu(self, axis:int, op:Ops, _include_initial=False) -> Tensor: + assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX) + pl_sz = self.shape[axis] - int(not _include_initial) + pooled = self.transpose(axis,-1).pad((pl_sz, -int(_include_initial)), value=identity_element(op, self.dtype))._pool((self.shape[axis],)) + return (pooled.sum(-1) if op is Ops.ADD else pooled.max(-1)).transpose(axis,-1) + + def _split_cumalu(self, axis:int, op:Ops) -> Tensor: + axis = self._resolve_dim(axis) + if self.ndim == 0 or 0 in self.shape: return self + # TODO: someday the optimizer will find this on it's own + # for now this is a two stage cumsum + SPLIT = 256 + if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumalu(axis, op) + ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0), value=identity_element(op, self.dtype)).unflatten(-1, (-1, SPLIT))._cumalu(-1, op) + base = ret[..., -1]._cumalu(-1, op, _include_initial=True) + base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1]) + def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1) + return fix(ret) + fix(base) if op is Ops.ADD else fix(ret).maximum(fix(base)) + def cumsum(self, axis:int=0) -> Tensor: """ - Computes the cumulative sum of the tensor along the specified axis. - - You can pass in the `axis` keyword argument to control the axis along which the cumulative sum is computed. + Computes the cumulative sum of the tensor along the specified `axis`. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.ones(2, 3) @@ -2209,17 +2222,21 @@ class Tensor(SimpleMathTrait): print(t.cumsum(1).numpy()) ``` """ - axis = self._resolve_dim(axis) - if self.ndim == 0 or 0 in self.shape: return self - # TODO: someday the optimizer will find this on it's own - # for now this is a two stage cumsum - SPLIT = 256 - if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumsum(axis) - ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0)).unflatten(-1, (-1, SPLIT))._cumsum(-1) - base_add = ret[..., -1]._cumsum(-1, _first_zero=True) - base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1]) - def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1) - return fix(ret) + fix(base_add) + return self._split_cumalu(axis, Ops.ADD) + + def cummax(self, axis:int=0) -> Tensor: + """ + Computes the cumulative max of the tensor along the specified `axis`. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([0, 1, -1, 2, -2, 3, -3]) + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.cummax(0).numpy()) + ``` + """ + return self._split_cumalu(axis, Ops.MAX) @staticmethod def _tri(r:sint, c:sint, diagonal:int=0, **kwargs) -> Tensor: