@@ -1,61 +1,62 @@
 import math, unittest
 from tinygrad import Tensor
 
+# TODO: make all the expectedFailure cases pass — i.e. UOp.__getitem__ should produce the same UOp graph as
+# Tensor.__getitem__ for every view-returning index pattern.
+
 def _t(*shape):
   return Tensor.arange(math.prod(shape)).reshape(*shape)
 
+# Tensor().func().uop should be the same as UOp.func()
+def _check(tc: unittest.TestCase, t: Tensor, fn):
+  tc.assertIs(fn(t).uop, fn(t.uop), f"\ntensor.uop = {fn(t).uop}\nuop = {fn(t.uop)}")
+
 class TestTensorUOpGetitem(unittest.TestCase):
-  """For each pattern, check that `Tensor(x)[idx].uop` equals `x.uop[idx]`."""
-
-  def _check(self, t: Tensor, idx):
-    via_tensor = t[idx].uop
-    via_uop = t.uop[idx]
-    self.assertIs(via_tensor, via_uop, f"\nidx={idx!r}\ntensor.uop = {via_tensor}\nuop[idx] = {via_uop}")
-
   # ---- pure slice patterns ----
-  def test_slice_full(self): self._check(_t(4), slice(None))
-  def test_slice_positive(self): self._check(_t(8), slice(1, 5))
-  def test_slice_open_start(self): self._check(_t(8), slice(None, 5))
-  def test_slice_open_stop(self): self._check(_t(8), slice(3, None))
-  def test_slice_negative_start(self): self._check(_t(8), slice(-3, None))
-  def test_slice_negative_stop(self): self._check(_t(8), slice(None, -2))
-  def test_slice_both_negative(self): self._check(_t(8), slice(-5, -1))
+  def test_slice_full(self): _check(self, _t(4), lambda x: x[slice(None)])
+  def test_slice_positive(self): _check(self, _t(8), lambda x: x[1:5])
+  def test_slice_open_start(self): _check(self, _t(8), lambda x: x[:5])
+  def test_slice_open_stop(self): _check(self, _t(8), lambda x: x[3:])
+  def test_slice_negative_start(self): _check(self, _t(8), lambda x: x[-3:])
+  def test_slice_negative_stop(self): _check(self, _t(8), lambda x: x[:-2])
+  def test_slice_both_negative(self): _check(self, _t(8), lambda x: x[-5:-1])
 
   # ---- slice with stride ----
-  def test_slice_stride(self): self._check(_t(6), slice(None, None, 2))
-  def test_slice_start_stop_stride(self): self._check(_t(6), slice(1, 5, 2))
-  def test_slice_reverse(self): self._check(_t(6), slice(None, None, -1))
-  def test_slice_singleton_negative_step(self): self._check(_t(8), slice(3, 2, -1))
+  def test_slice_stride(self): _check(self, _t(6), lambda x: x[::2])
+  def test_slice_start_stop_stride(self): _check(self, _t(6), lambda x: x[1:5:2])
+  def test_slice_reverse(self): _check(self, _t(6), lambda x: x[::-1])
+  def test_slice_singleton_negative_step(self): _check(self, _t(8), lambda x: x[3:2:-1])
 
   # ---- empty / out-of-bounds slice ----
-  def test_slice_empty(self): self._check(_t(6), slice(3, 1))
-  def test_slice_oob_stop(self): self._check(_t(6), slice(0, 100))
+  def test_slice_empty(self): _check(self, _t(6), lambda x: x[3:1])
+  def test_slice_oob_stop(self): _check(self, _t(6), lambda x: x[0:100])
 
   # ---- single int (reduces a dim) ----
-  def test_int_positive(self): self._check(_t(8), 3)
-  def test_int_negative(self): self._check(_t(8), -1)
+  def test_int_positive(self): _check(self, _t(8), lambda x: x[3])
+  def test_int_negative(self): _check(self, _t(8), lambda x: x[-1])
 
   # ---- ellipsis ----
-  def test_ellipsis_only(self): self._check(_t(2, 3, 4), (Ellipsis,))
-  def test_ellipsis_then_int(self): self._check(_t(2, 3, 4), (Ellipsis, -1))
-  def test_ellipsis_then_slice(self): self._check(_t(2, 3, 4), (Ellipsis, slice(1, 3)))
-  def test_ellipsis_then_none(self): self._check(_t(2, 3), (Ellipsis, None))
+  def test_ellipsis_only(self): _check(self, _t(2, 3, 4), lambda x: x[...])
+  def test_ellipsis_then_int(self): _check(self, _t(2, 3, 4), lambda x: x[..., -1])
+  def test_ellipsis_then_slice(self): _check(self, _t(2, 3, 4), lambda x: x[..., 1:3])
+  def test_ellipsis_then_none(self): _check(self, _t(2, 3), lambda x: x[..., None])
 
   # ---- None (unsqueeze) ----
-  def test_none_front(self): self._check(_t(4), (None,))
-  def test_none_back(self): self._check(_t(4), (slice(None), None))
-  def test_none_middle(self): self._check(_t(2, 3), (slice(None), None, slice(None)))
-  def test_multiple_none(self): self._check(_t(2, 3), (None, slice(None), None))
+  def test_none_front(self): _check(self, _t(4), lambda x: x[None])
+  def test_none_back(self): _check(self, _t(4), lambda x: x[:, None])
+  def test_none_middle(self): _check(self, _t(2, 3), lambda x: x[:, None, :])
+  def test_multiple_none(self): _check(self, _t(2, 3), lambda x: x[None, :, None])
 
   # ---- mixed multi-dim ----
-  def test_int_then_slice(self): self._check(_t(2, 3), (1, slice(None)))
-  def test_multi_int(self): self._check(_t(2, 3, 4), (1, 2))
-  def test_mixed_slice_int(self): self._check(_t(2, 3, 4), (slice(0, 2), -1, slice(1, 3)))
-  def test_mixed_slice_slice(self): self._check(_t(3, 4, 5), (slice(1, 3), slice(None), slice(0, 2)))
-  def test_high_rank_combo(self): self._check(_t(4, 5, 6), (slice(1, 3), slice(None), -1, None))
+  def test_int_then_slice(self): _check(self, _t(2, 3), lambda x: x[1, :])
+  def test_multi_int(self): _check(self, _t(2, 3, 4), lambda x: x[1, 2])
+  def test_mixed_slice_int(self): _check(self, _t(2, 3, 4), lambda x: x[0:2, -1, 1:3])
+  def test_mixed_slice_slice(self): _check(self, _t(3, 4, 5), lambda x: x[1:3, :, 0:2])
+  def test_high_rank_combo(self): _check(self, _t(4, 5, 6), lambda x: x[1:3, :, -1, None])
 
+class TestTensorUOpCumalu(unittest.TestCase):
+  def test_cumsum_1d(self): _check(self, _t(5), lambda x: x.cumsum())
+  def test_cumsum_2d(self): _check(self, _t(3, 4), lambda x: x.cumsum(1))
+  def test_cumsum_non_last(self): _check(self, _t(3, 4), lambda x: x.cumsum(0))
+  def test_cumsum_large(self): _check(self, _t(600), lambda x: x.cumsum())  # exercises _split_cumalu
+  def test_cumprod(self): _check(self, _t(4), lambda x: x.cumprod(0))
+
 if __name__ == "__main__":
   unittest.main()
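
A note on the helper: `_check` uses `assertIs`, so the two index paths must build the identical UOp object, not merely an equal one; object identity is the contract here. A minimal standalone sketch of the same property, runnable outside unittest (assuming a tinygrad checkout on PYTHONPATH; per the TODO comment above, not every pattern is expected to pass yet):

```python
# Standalone sketch of the property under test: indexing the Tensor and
# indexing its underlying UOp should build the identical (is-equal) graph.
from tinygrad import Tensor

t = Tensor.arange(8)
for fn in (lambda x: x[1:5], lambda x: x[::-1], lambda x: x[None]):
  # same contract as _check above: identity, not value equality
  assert fn(t).uop is fn(t.uop), f"mismatch for {fn}"
```
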
@@ -1,15 +1,26 @@
 import functools
 from typing import Self, Sequence, Literal, get_args
 from tinygrad.mixin.elementwise import ElementwiseMixin
 from tinygrad.mixin.movement import MovementMixin
 from tinygrad.mixin.reduce import ReduceMixin
-from tinygrad.uop.ops import _broadcast_shape, resolve
+from tinygrad.uop import Ops
+from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element
 from tinygrad.dtype import DTypeLike, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
-from tinygrad.helpers import argfix, prod
+from tinygrad.helpers import argfix, flatten, prod, round_up
 
 ReductionStr = Literal["mean", "sum", "none"]
 
 
 class OpMixin(ElementwiseMixin, ReduceMixin):
+  def _pad_constant(self, pX, value:float) -> Self:
+    # shrink first for negative pads, then pad with only non-negative values
+    pX = tuple((0, 0) if p is None else p for p in pX)
+    has_neg = not all(resolve(p >= 0) for p in flatten(pX))
+    X = self.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, self.shape))) if has_neg else self
+    pads = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX) if has_neg else pX
+    if value == 0: return MovementMixin.pad(X, pads)
+    return MovementMixin.pad(X, pads) + MovementMixin.pad(X.ones_like(), pads).cast(dtypes.bool).where(0, value)
+
   def _broadcasted(self, y, reverse=False) -> tuple[Self, Self]:
     if not isinstance(y, type(self)): y = self.ufix(y)
     x, y = (self, y) if not reverse else (y, self)
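
The `_pad_constant` moved into `OpMixin` above packs two tricks: negative pad amounts are turned into a `shrink` before padding, and a nonzero fill value is synthesized by padding a ones-mask and `where`-ing the fill into the zeroed region. A NumPy sketch of the same decomposition (illustrative only; `pad_constant` here is a hypothetical stand-in, not tinygrad API):

```python
import numpy as np

def pad_constant(x: np.ndarray, pads, value=0.0):
  # Negative pads crop the array: shrink first, then pad with the
  # remaining non-negative amounts (same order as OpMixin._pad_constant).
  slices = tuple(slice(max(-b, 0), s - max(-a, 0)) for (b, a), s in zip(pads, x.shape))
  x = x[slices]
  pos = tuple((max(b, 0), max(a, 0)) for b, a in pads)
  if value == 0: return np.pad(x, pos)
  # pad a ones-mask with zeros: the zeros mark the padded region, which is
  # then filled with `value` (mirrors the ones_like/where trick above)
  mask = np.pad(np.ones_like(x), pos)
  return np.pad(x, pos) + np.where(mask.astype(bool), 0, value)

a = np.arange(4.0)
print(pad_constant(a, [(-1, 2)], value=7.0))  # crops one from the front, appends two 7s -> [1. 2. 3. 7. 7.]
```
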
@@ -249,6 +260,54 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
     m = self.max(axis=axis, keepdim=True)
     return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + (m if keepdim else m.squeeze(axis))
 
+  def _cumalu(self, axis:int, op:Ops) -> Self:
+    assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX, Ops.MUL)
+    pads = (None,)*(self.ndim-1) + ((self.shape[axis]-1, 0),)
+    pooled = self.transpose(axis,-1)._pad_constant(pads, identity_element(op, self.dtype))._pool((self.shape[axis],))
+    return getattr(pooled, {Ops.ADD: "sum", Ops.MAX: "max", Ops.MUL: "prod"}[op])(-1).transpose(axis, -1)
+
+  def _split_cumalu(self, axis:int, op:Ops) -> Self:
+    axis = self._resolve_dim(axis)
+    if self.ndim == 0 or 0 in self.shape: return self
+    # TODO: someday the optimizer will find this on its own
+    # for now this is a two stage cumsum
+    SPLIT = 256
+    value = identity_element(op, self.dtype)
+    if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumalu(axis, op)
+    ret = self.transpose(axis,-1)._pad_constant((None,)*(self.ndim-1)+((round_up(s,SPLIT)-s,0),), value).unflatten(-1,(-1,SPLIT))._cumalu(-1, op)
+    base = ret[..., -1]._cumalu(-1, op)._pad_constant((None,)*(ret.ndim-2) + ((1, -1),), value)
+    base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
+    def fix(x: Self) -> Self: return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
+    return getattr(fix(ret), {Ops.ADD: "add", Ops.MAX: "maximum", Ops.MUL: "mul"}[op])(fix(base))
+
+  def cumsum(self, axis:int=0) -> Self:
+    """
+    Computes the cumulative sum of the tensor along the specified `axis`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.ones(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.cumsum(1).numpy())
+    ```
+    """
+    return self._split_cumalu(axis, Ops.ADD)
+
+  def cumprod(self, axis:int) -> Self:
+    """
+    Computes the cumulative product of the elements of the tensor along the specified `axis`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(1, 7).reshape(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.cumprod(axis=0).numpy())
+    ```
+    """
+    return self._split_cumalu(axis, Ops.MUL)
+
   # ***** functional nn ops *****
 
   def linear(self, weight:Self, bias:Self|None=None, dtype:DTypeLike|None=None) -> Self:
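
`_split_cumalu` is the interesting piece: for a long static axis it pads the front up to a multiple of `SPLIT`, takes the cumulative reduction within each `SPLIT`-sized block, then cumulates the per-block totals, shifts them right by one block (the `(1, -1)` pad), and broadcasts those offsets back over the blocks. A NumPy sketch of the same two-stage scheme, specialized to addition on the last axis (the code above handles MAX/MUL via `identity_element` and transposes `axis` to -1 first):

```python
import numpy as np

def split_cumsum(x: np.ndarray, SPLIT: int = 256) -> np.ndarray:
  # Two-stage cumsum along the last axis, mirroring _split_cumalu with
  # op=Ops.ADD (identity element 0).
  s = x.shape[-1]
  pad = -s % SPLIT                                        # round_up(s, SPLIT) - s
  xp = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad, 0)])    # pad the front, as above
  blocks = xp.reshape(*xp.shape[:-1], -1, SPLIT)
  ret = np.cumsum(blocks, axis=-1)                        # stage 1: within each block
  totals = ret[..., -1]                                   # running total at each block end
  base = np.cumsum(totals, axis=-1)                       # stage 2: across blocks
  base = np.pad(base, [(0, 0)] * (base.ndim - 1) + [(1, 0)])[..., :-1]  # the (1, -1) shift
  out = ret + base[..., None]                             # broadcast block offsets back
  return out.reshape(*x.shape[:-1], -1)[..., -s:]         # drop the front padding

x = np.arange(600)
assert np.array_equal(split_cumsum(x), np.cumsum(x))
```

The front padding is filled with the identity element, so it contributes nothing to the scan; only the last `s` entries are kept at the end, which is exactly what `fix` does above.
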
@@ -11,7 +11,7 @@ from tinygrad.helpers import IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, ceildiv,
 from tinygrad.helpers import suppress_finalizing, disable_gc
 from tinygrad.gradient import compute_gradient
 from tinygrad.mixin import OpMixin, ReductionStr
-from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, Variable
+from tinygrad.uop.ops import smax, UOp, Ops, sint, all_metadata, _index_to_concrete_int, sint_to_uop, Variable
 from tinygrad.uop.ops import _broadcast_shape
 from tinygrad.engine.schedule import ExecItem, complete_create_schedule_with_vars
 from tinygrad.device import Buffer, canonicalize_device
@@ -1054,14 +1054,6 @@ class Tensor(OpMixin):
   def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, extra_args=(op,), arg=arg)
   def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor: return self._apply_uop(UOp._rop, op=op, axis=axis)
 
-  def _pad_constant(self, pX:tuple[tuple[sint, sint], ...], value:float) -> Tensor:
-    # shrink first for negative pads, then pad with only non-negative values
-    has_neg = not all(resolve(p >= 0) for p in flatten(pX))
-    X = self.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, self.shape))) if has_neg else self
-    pads = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX) if has_neg else pX
-    if value == 0: return X._apply_uop(UOp.pad, arg=pads)
-    return X._apply_uop(UOp.pad, arg=pads) + Tensor.ones_like(X)._apply_uop(UOp.pad, arg=pads).where(0, value)
-
   def _pad_circular(self, pX:tuple[tuple[sint, sint], ...]) -> Tensor:
     if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, self.shape)): raise ValueError('Padding value causes wrapping around more than once.')
     if any(pB<0 or pA<0 for pB,pA in pX): raise NotImplementedError("Negative pads with circular pads is not supported")
@@ -1921,53 +1913,6 @@ class Tensor(OpMixin):
     if IMAGE: return self.image_dot(w, dtype)
     return super().dot(w, dtype=dtype)
 
-  def _cumalu(self, axis:int, op:Ops) -> Tensor:
-    assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX, Ops.MUL)
-    pooled = self.transpose(axis,-1).pad((self.shape[axis]-1, 0), value=identity_element(op, self.dtype))._pool((self.shape[axis],))
-    return getattr(pooled, {Ops.ADD: "sum", Ops.MAX: "max", Ops.MUL: "prod"}[op])(-1).transpose(axis, -1)
-
-  def _split_cumalu(self, axis:int, op:Ops) -> Tensor:
-    axis = self._resolve_dim(axis)
-    if self.ndim == 0 or 0 in self.shape: return self
-    # TODO: someday the optimizer will find this on its own
-    # for now this is a two stage cumsum
-    SPLIT = 256
-    value = identity_element(op, self.dtype)
-    if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumalu(axis, op)
-    ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0), value=value).unflatten(-1, (-1, SPLIT))._cumalu(-1, op)
-    base = ret[..., -1]._cumalu(-1, op).pad((1, -1), value=value)
-    base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
-    def fix(x: Tensor) -> Tensor: return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
-    return getattr(fix(ret), {Ops.ADD: "add", Ops.MAX: "maximum", Ops.MUL: "mul"}[op])(fix(base))
-
-  def cumsum(self, axis:int=0) -> Tensor:
-    """
-    Computes the cumulative sum of the tensor along the specified `axis`.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.ones(2, 3)
-    print(t.numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.cumsum(1).numpy())
-    ```
-    """
-    return self._split_cumalu(axis, Ops.ADD)
-
-  def cumprod(self, axis:int) -> Tensor:
-    """
-    Computes the cumulative product of the elements of the tensor along the specified `axis`.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.arange(1, 7).reshape(2, 3)
-    print(t.numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.cumprod(axis=0).numpy())
-    ```
-    """
-    return self._split_cumalu(axis, Ops.MUL)
-
   def cummax(self, axis:int=0) -> tuple[Tensor, Tensor]:
     """
     Computes the cumulative max of the tensor along `axis`, returning (values, indices).