Tensor.cummax (#7854)
generalized the existing cumsum to take Ops.MAX in addition to Ops.ADD
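As a quick illustration of the resulting API (not part of the diff; assumes a tinygrad build that includes this commit), `cummax` sits next to `cumsum` and follows the same calling convention:

```python
from tinygrad import Tensor

t = Tensor([0, 1, -1, 2, -2, 3, -3])
print(t.cumsum(0).numpy())  # running sum along axis 0: [0 1 0 2 0 3 0]
print(t.cummax(0).numpy())  # running max along axis 0: [0 1 1 2 2 3 3]
```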
@@ -27,6 +27,7 @@
 ::: tinygrad.Tensor.matmul
 ::: tinygrad.Tensor.einsum
 ::: tinygrad.Tensor.cumsum
+::: tinygrad.Tensor.cummax
 ::: tinygrad.Tensor.triu
 ::: tinygrad.Tensor.tril
 ::: tinygrad.Tensor.interpolate
@@ -135,7 +135,7 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   # approximate top p
   # because we are already limited to top k elements we can do top p "without sorting"
-  output_cumsum = output[::-1]._cumsum()[::-1] + t.sum()
+  output_cumsum = output[::-1].cumsum()[::-1] + t.sum()
   output = (output_cumsum >= (1 - p)) * output
   output_indices = (output_cumsum >= (1 - p)) * output_indices
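The trick the comments describe survives the rename: a reversed cumulative sum turns the top-k probabilities into suffix sums, and thresholding those against 1 - p masks the low-probability tail without a sort. A rough NumPy sketch of that idea (values and variable names are illustrative, not taken from the diff):

```python
import numpy as np

p = 0.9
# top-k probabilities, largest first (illustrative values)
probs = np.array([0.60, 0.18, 0.10, 0.07, 0.05])
# reversed cumsum: each entry's own mass plus everything ranked below it
suffix = probs[::-1].cumsum()[::-1]   # [1.00, 0.40, 0.22, 0.12, 0.05]
keep = suffix >= (1 - p)              # entries whose tail mass falls below 1-p are dropped
print(probs * keep)                   # [0.6, 0.18, 0.1, 0.07, 0.0], kept mass >= p
```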
@@ -6,6 +6,7 @@ from tinygrad.engine.realize import run_schedule
 from tinygrad.codegen.kernel import Opt, OptOps, Kernel, KernelOptError
 from tinygrad.engine.realize import CompiledRunner, ExecItem
 from tinygrad.engine.search import get_kernel_actions
+from tinygrad.ops import Ops

 class TestArange(unittest.TestCase):
   def _get_flops(self, N, opts=None):
@@ -86,7 +87,7 @@ class TestIndexing(unittest.TestCase):
     print("*** indexing ***")
     with Context(NOOPT=1, FUSE_ARANGE=1):
       GlobalCounters.reset()
-      rng = Tensor.ones(4, 256, 16384, dtype=dtypes.int)._cumsum(axis=-1, _first_zero=True).reshape(4, 256, 16384, 1)
+      rng = Tensor.ones(4, 256, 16384, dtype=dtypes.int)._cumalu(axis=-1, op=Ops.ADD, _include_initial=True).reshape(4, 256, 16384, 1)
       idxs = idxs.reshape(4,1,1,1).expand(4, 256, 16384, 1)
       reshape_dataset = dataset.T.reshape(1, 256, 16384, 1).expand(4, 256, 16384, 1)
       full = (rng==idxs).where(reshape_dataset, Tensor.zeros(4, 256, 16384, 1))
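In this test the running sum of a tensor of ones builds the index ramp that is later compared against `idxs`; the renamed `_include_initial` flag (formerly `_first_zero`) makes the scan start from the op's identity rather than the first element. A small NumPy sketch of that shifted scan (illustrative only, not the internal API):

```python
import numpy as np

ones = np.ones(8, dtype=np.int32)
inclusive = ones.cumsum()                            # [1 2 3 4 5 6 7 8]
# "include initial": prepend the identity (0 for ADD) and drop the last element
with_initial = np.concatenate(([0], inclusive[:-1]))
print(with_initial)                                  # [0 1 2 3 4 5 6 7], an index ramp
```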
@@ -801,6 +801,27 @@ class TestOps(unittest.TestCase):
     helper_test_op([(0,3)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0))
     helper_test_op([(2,3,0)], lambda x: torch.cumsum(x, dim=2), lambda x: Tensor.cumsum(x, axis=2))

+  def test_small_cummax(self):
+    helper_test_op([(10)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
+  def test_simple_cummax(self):
+    helper_test_op([(512)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
+    helper_test_op([(1022)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
+  def test_cummax(self):
+    helper_test_op([()], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
+    # TODO: torch allows this?
+    # self.helper_test_exception([()], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1), expected=IndexError)
+    helper_test_op([(20,)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
+    self.helper_test_exception([(20,)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1), expected=IndexError)
+    self.helper_test_exception([(20,)], lambda x: torch.cummax(x, dim=-2).values, lambda x: Tensor.cummax(x, axis=-2), expected=IndexError)
+    helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
+    helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1))
+    helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=2).values, lambda x: Tensor.cummax(x, axis=2))
+    helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=-1).values, lambda x: Tensor.cummax(x, axis=-1))
+  def test_cummax_zero_axis(self):
+    helper_test_op([(2,0,4)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1))
+    helper_test_op([(0,3)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
+    helper_test_op([(2,3,0)], lambda x: torch.cummax(x, dim=2).values, lambda x: Tensor.cummax(x, axis=2))
+
   def test_argmax(self):
     # check if it returns the first index for multiple occurences
     self.assertEqual(torch.tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy())
@@ -183,7 +183,7 @@ class GroupOp:
   UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV}

 # https://en.wikipedia.org/wiki/Identity_element
-def identity_element(op:Ops, dt:DType): return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
+def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)

 def can_pad(u:UOp) -> bool: return not any(x.op in GroupOp.UnsafePad for x in u.sparents)
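`identity_element` matters here because the new `_cumalu` (further down) pads the scanned axis before pooling; padding with the op's identity guarantees the padding cannot change the result. A short sketch of the mapping the diff shows, assuming a tinygrad checkout that includes this commit:

```python
from tinygrad import dtypes
from tinygrad.ops import Ops, identity_element

print(identity_element(Ops.ADD, dtypes.float32))  # 0.0, adding it is a no-op
print(identity_element(Ops.MUL, dtypes.float32))  # 1.0
print(identity_element(Ops.MAX, dtypes.int8))     # -128, i.e. dtypes.min(dtypes.int8)
```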
@@ -9,7 +9,7 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, leas
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
 from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN
 from tinygrad.multi import MultiLazyBuffer
-from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait
+from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
 from tinygrad.device import Device, Buffer, BufferSpec
 from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.engine.realize import run_schedule
@@ -607,7 +607,7 @@ class Tensor(SimpleMathTrait):
     dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
     # NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
     if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs)
-    return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
+    return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype)

   @staticmethod
   def linspace(start:Union[int, float], stop:Union[int, float], steps:int, **kwargs) -> Tensor:
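The arange construction itself is unchanged in spirit: fill a length-`output_len` tensor with `step`, take a running sum, and shift by `start - step` so the first element equals `start`. The same arithmetic in NumPy terms (illustrative, not the tinygrad code path):

```python
import numpy as np

start, stop, step = 3, 12, 2
output_len = -(-(stop - start) // step)      # ceildiv(stop-start, step) -> 5
full = np.full(output_len, step)             # [2 2 2 2 2]
print(full.cumsum() + (start - step))        # [ 3  5  7  9 11], same as np.arange(3, 12, 2)
```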
@@ -2191,15 +2191,28 @@ class Tensor(SimpleMathTrait):
     """
     return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)

-  def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
-    assert self.shape[axis] != 0
-    pl_sz = self.shape[axis] - int(not _first_zero)
-    return self.transpose(axis,-1).pad((pl_sz,-int(_first_zero)))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
+  def _cumalu(self, axis:int, op:Ops, _include_initial=False) -> Tensor:
+    assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX)
+    pl_sz = self.shape[axis] - int(not _include_initial)
+    pooled = self.transpose(axis,-1).pad((pl_sz, -int(_include_initial)), value=identity_element(op, self.dtype))._pool((self.shape[axis],))
+    return (pooled.sum(-1) if op is Ops.ADD else pooled.max(-1)).transpose(axis,-1)
+
+  def _split_cumalu(self, axis:int, op:Ops) -> Tensor:
+    axis = self._resolve_dim(axis)
+    if self.ndim == 0 or 0 in self.shape: return self
+    # TODO: someday the optimizer will find this on it's own
+    # for now this is a two stage cumsum
+    SPLIT = 256
+    if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumalu(axis, op)
+    ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0), value=identity_element(op, self.dtype)).unflatten(-1, (-1, SPLIT))._cumalu(-1, op)
+    base = ret[..., -1]._cumalu(-1, op, _include_initial=True)
+    base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
+    def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
+    return fix(ret) + fix(base) if op is Ops.ADD else fix(ret).maximum(fix(base))

   def cumsum(self, axis:int=0) -> Tensor:
     """
-    Computes the cumulative sum of the tensor along the specified axis.
-    You can pass in the `axis` keyword argument to control the axis along which the cumulative sum is computed.
+    Computes the cumulative sum of the tensor along the specified `axis`.

     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor.ones(2, 3)
@@ -2209,17 +2222,21 @@
     print(t.cumsum(1).numpy())
     ```
     """
-    axis = self._resolve_dim(axis)
-    if self.ndim == 0 or 0 in self.shape: return self
-    # TODO: someday the optimizer will find this on it's own
-    # for now this is a two stage cumsum
-    SPLIT = 256
-    if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumsum(axis)
-    ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0)).unflatten(-1, (-1, SPLIT))._cumsum(-1)
-    base_add = ret[..., -1]._cumsum(-1, _first_zero=True)
-    base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
-    def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
-    return fix(ret) + fix(base_add)
+    return self._split_cumalu(axis, Ops.ADD)
+
+  def cummax(self, axis:int=0) -> Tensor:
+    """
+    Computes the cumulative max of the tensor along the specified `axis`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([0, 1, -1, 2, -2, 3, -3])
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.cummax(0).numpy())
+    ```
+    """
+    return self._split_cumalu(axis, Ops.MAX)

   @staticmethod
   def _tri(r:sint, c:sint, diagonal:int=0, **kwargs) -> Tensor:
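`_split_cumalu` keeps the old two-stage structure: scan fixed SPLIT-sized blocks, scan the per-block totals starting from the identity, and combine the two, with `+` for Ops.ADD and `maximum` for Ops.MAX. A minimal NumPy sketch of that two-stage scan for the MAX case (block size and data are illustrative):

```python
import numpy as np

SPLIT = 4
x = np.random.randint(0, 100, size=16)
blocks = x.reshape(-1, SPLIT)
intra = np.maximum.accumulate(blocks, axis=1)                   # stage 1: scan inside each block
carry = np.maximum.accumulate(intra[:, -1])                     # scan of the per-block maxima
carry = np.concatenate(([np.iinfo(x.dtype).min], carry[:-1]))   # shift so block i only sees earlier blocks
two_stage = np.maximum(intra, carry[:, None]).reshape(-1)       # stage 2: fold the carry back in
assert (two_stage == np.maximum.accumulate(x)).all()            # matches a single full scan
print(two_stage)
```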