slicing + allclose

@@ -1,5 +1,6 @@
 import unittest
 from tinygrad import Tensor, UOp, Context
+from tinygrad.dtype import AddrSpace
 from tinygrad.uop.ops import KernelInfo, AxisType
 
 # **** kernels ****

@@ -9,15 +10,18 @@ def custom_arange_kernel(C:UOp) -> UOp:
   return C[i].store(i.cast(C.dtype.base)).end(i).sink(arg=KernelInfo(name=f"custom_arange_{C.size}"))
 
 def custom_add_one_kernel(B:UOp, A:UOp) -> UOp:
+  A,B = A.flatten(), B.flatten()
   assert B.size == A.size
   i = UOp.range(A.size, 0)
   return B[i].store(A[i] + 1).end(i).sink(arg=KernelInfo(name=f"add_one_{A.size}"))
 
 def custom_elementwise_add_kernel(C:UOp, A:UOp, B:UOp) -> UOp:
+  C,A,B = C.flatten(), A.flatten(), B.flatten()
   i = UOp.range(C.size, 0)
   return C[i].store(A[i]+B[i]).end(i).sink(arg=KernelInfo(name=f"custom_add_kernel_{C.size}")).simplify()
 
 def custom_elementwise_addmul_kernel(C:UOp, D:UOp, A:UOp, B:UOp) -> UOp:
+  C,D,A,B = C.flatten(), D.flatten(), A.flatten(), B.flatten()
   assert C.size == D.size
   i = UOp.range(C.size, 0)
   store_c = C[i].store(A[i]+B[i])

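Note: the `.flatten()` lines added to these kernels pair with the new `UOp.__getitem__` in the last hunk, which now asserts one index argument per shape dimension; since these kernels index with a single flat range, they flatten their buffers first.
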
@@ -39,13 +43,22 @@ def custom_sum(B:UOp, A:UOp) -> UOp:
   return B.sink(arg=KernelInfo(name=f"custom_sum_{A.shape[0]}", opts_to_apply=()))
 
 def flip_contract_kernel(dest:UOp, src:UOp):
-  assert dest.size%4 == 0
-  i = UOp.range(dest.size//4, 0)
-  j = UOp.range(4, 1, AxisType.UPCAST)
-  vec = src[i*4+j].contract(j)
-  store = UOp.group(*[dest[i*4+k].store(vec.gep(3-k)) for k in range(4)])
+  i = UOp.range(dest.shape[0], 0)
+  j = UOp.range(dest.shape[1], 1, AxisType.UPCAST)
+  vec = src[i, j].contract(j)
+  store = UOp.group(*[dest[i, k].store(vec.gep(3-k)) for k in range(4)])
   return store.end(i).sink(arg=KernelInfo(name=f"flip_contract_{dest.size}", opts_to_apply=()))
 
+def slice_sum_kernel(dest:UOp, src:UOp):
+  G = UOp.range(src.shape[0], 0, AxisType.GLOBAL)
+  slice_src = src[G, :]
+  reg = UOp.placeholder((1,), dest.dtype.base, 0, addrspace=AddrSpace.REG)
+  reg = reg.after(G)[0].set(0)
+  R = UOp.range(src.shape[1], 1, AxisType.REDUCE)
+  reg = reg[0].set(reg[0] + slice_src[R], end=R)
+  ast = dest[G].set(reg[0], end=G)
+  return ast.sink(arg=KernelInfo(name=f"slice_sum_{src.shape[0]}_{src.shape[1]}", opts_to_apply=()))
+
 # **** backward callbacks ****
 
 def backward_gemm(gradient:UOp, kernel:UOp) -> tuple[UOp, UOp]:

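For reference, `slice_sum_kernel` computes per-row sums. A minimal plain-Python sketch of the same semantics (`dest`/`src` here are ordinary lists standing in for device buffers; the real kernel runs `G` as a GLOBAL axis and `R` as a REDUCE axis instead of Python loops):

```python
def slice_sum_reference(dest, src):
  for g in range(len(src)):        # G: GLOBAL range over the rows
    reg = 0                        # reg.after(G)[0].set(0): zero a register per row
    for r in range(len(src[g])):   # R: REDUCE range over slice_src = src[G, :]
      reg += src[g][r]             # reg[0].set(reg[0] + slice_src[R], end=R)
    dest[g] = reg                  # dest[G].set(reg[0], end=G)

src = [[1, 2], [3, 4]]
dest = [0, 0]
slice_sum_reference(dest, src)
assert dest == [3, 7]
```
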
@@ -111,6 +124,12 @@ class TestCustomKernel(unittest.TestCase):
     b = Tensor.custom_kernel(tst, a, fxn=custom_sum)[0]
     self.assertEqual(b.item(), 15)
 
+  def test_slice_sum(self):
+    A = Tensor.randn(16, 16)
+    B = Tensor.empty(16)
+    B = Tensor.custom_kernel(B, A, fxn=slice_sum_kernel)[0]
+    self.assertTrue(B.allclose(A.sum(1)))
+
   def test_gemm(self):
     N = 16
     a = Tensor.randn(N, N)

@@ -4,7 +4,7 @@ from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC
 from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat
 from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
 from tinygrad.renderer import Renderer
-from tinygrad.dtype import dtypes
+from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.helpers import panic
 
 # import all pattern matchers here

@@ -19,13 +19,19 @@ from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_s
 from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops
 from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
 
+pm_syntactic_sugar = PatternMatcher([
+  # INDEX on ptr INDEX concats them
+  (UPat(Ops.INDEX, name="i1").f(Ops.INDEX, name="i2", allow_any_len=True),
+   lambda i1,i2: i2.replace(src=i1.src+i2.src[1:]) if isinstance(i1.dtype, PtrDType) and not isinstance(i2.dtype, PtrDType) else None),
+])
+
 def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
   if ren is None: ren = Renderer()
 
   if SPEC: type_verify(sink, kernel_spec)
 
   # preprocess
-  sink = graph_rewrite(sink, pm_mops, name="early movement ops")
+  sink = graph_rewrite(sink, pm_mops+pm_syntactic_sugar, name="early movement ops", bottom_up=True)
 
   # first we optimize
   if optimize:

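The new matcher makes chained indexing spell a single access: an INDEX whose first source is itself a pointer-typed INDEX folds into one INDEX with the index sources concatenated (`i1.src + i2.src[1:]`). That is what lets `slice_src = src[G, :]` be indexed again as `slice_src[R]` and lower the same as `src[G, R]`. A toy sketch of the folding on tuples (a hypothetical `("INDEX", buffer, *indices)` encoding, not real UOps; the real pattern additionally requires `i1` to be pointer-typed and `i2` not):

```python
def concat_index(i2):
  _op, first, *idx2 = i2
  if isinstance(first, tuple) and first[0] == "INDEX":  # i1 is itself an INDEX
    _, buf, *idx1 = first
    return ("INDEX", buf, *idx1, *idx2)                 # src = i1.src + i2.src[1:]
  return i2

inner = ("INDEX", "src", "G")   # src[G, :] -> pointer INDEX, one index so far
outer = ("INDEX", inner, "R")   # slice_src[R] -> INDEX on that pointer INDEX
assert concat_index(outer) == ("INDEX", "src", "G", "R")
```
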
@@ -1435,11 +1435,6 @@ class Tensor(MathMixin, MovementMixin):
     final_shape = [r*s for r,s in zip(repeats, base_shape)]
     return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape)
 
-  def _resolve_dim(self, dim:int, *, extra:bool=False) -> int:
-    total = self.ndim + int(extra)
-    if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
-    return dim + total if dim < 0 else dim
-
   def split(self, sizes:int|Sequence[int], dim:int=0) -> tuple[Tensor, ...]:
     """
     Splits the tensor into chunks along the dimension specified by `dim`.

@@ -1597,22 +1592,6 @@ class Tensor(MathMixin, MovementMixin):
     order[dim0], order[dim1] = order[dim1], order[dim0]
     return self.permute(order)
 
-  def flatten(self, start_dim=0, end_dim=-1) -> Tensor:
-    """
-    Flattens the tensor by reshaping it into a one-dimensional tensor.
-    If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.arange(8).reshape(2, 2, 2)
-    print(t.flatten().numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.flatten(start_dim=1).numpy())
-    ```
-    """
-    start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
-    return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
-
   def unflatten(self, dim:int, sizes:tuple[int,...]) -> Tensor:
     """
     Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.

@@ -1927,6 +1906,12 @@ class Tensor(MathMixin, MovementMixin):
     is_nan_close = (self.isnan() & other.isnan()) & equal_nan
     return is_finite_close | is_infinite_close | is_nan_close
 
+  def allclose(self, other:Tensor, rtol:float=1e-05, atol:float=1e-08, equal_nan=False) -> bool:
+    """
+    Check if all self and other are close. Return True or False.
+    """
+    return bool(self.isclose(other, rtol=rtol, atol=atol, equal_nan=equal_nan).all().item())
+
   def mean(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
     """
     Returns the mean value of the tensor along the specified axis or axes.

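`allclose` is a thin reduction over `isclose`: elementwise closeness, then `.all().item()`. Judging from the context lines above (finite, infinite, and NaN cases OR'd together), `isclose` follows the usual numpy-style criterion; a scalar sketch assuming that criterion, with the same defaults:

```python
import math

def isclose_scalar(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
  if math.isnan(a) or math.isnan(b): return equal_nan and math.isnan(a) and math.isnan(b)
  if math.isinf(a) or math.isinf(b): return a == b   # infinities must match exactly
  return abs(a - b) <= atol + rtol * abs(b)          # finite case

assert isclose_scalar(1.0, 1.0 + 5e-6)   # 5e-6 is inside atol + rtol*|b| ~= 1e-5
assert not isclose_scalar(0.0, 1e-6)     # near zero only atol matters, and 1e-8 < 1e-6
```
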
@@ -4194,29 +4179,6 @@ class Tensor(MathMixin, MovementMixin):
 
   # ***** Tensor Properties *****
 
-  @property
-  def ndim(self) -> int:
-    """
-    Returns the number of dimensions in the tensor.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor([[1, 2], [3, 4]])
-    print(t.ndim)
-    ```
-    """
-    return len(self.shape)
-
-  def numel(self) -> sint:
-    """
-    Returns the total number of elements in the tensor.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
-    print(t.numel())
-    ```
-    """
-    return prod(self.shape)
-
   def element_size(self) -> int:
     """
     Returns the size in bytes of an individual element in the tensor.

@@ -183,6 +183,34 @@ class MovementMixin:
   def shape(self) -> tuple["sint", ...]: raise NotImplementedError
 
   # great functions you get!
+  @property
+  def ndim(self) -> int:
+    """
+    Returns the number of dimensions in the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[1, 2], [3, 4]])
+    print(t.ndim)
+    ```
+    """
+    return len(self.shape)
+
+  def numel(self) -> "sint":
+    """
+    Returns the total number of elements in the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+    print(t.numel())
+    ```
+    """
+    return prod(self.shape)
+
+  def _resolve_dim(self, dim:int, *, extra:bool=False) -> int:
+    total = self.ndim + int(extra)
+    if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
+    return dim + total if dim < 0 else dim
+
   def view(self, shape, *args) -> Self:
     """`.view` is an alias for `.reshape`."""
     return self.reshape(shape, *args)

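`_resolve_dim` (moved here from `Tensor`, unchanged) normalizes negative axis arguments; `extra=True` widens the valid range by one, presumably for dim-inserting ops that may address the slot past the last axis. A standalone sketch of the same contract:

```python
def resolve_dim(ndim: int, dim: int, extra: bool = False) -> int:
  total = ndim + int(extra)
  if not -max(1, total) <= dim <= max(1, total)-1:
    raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
  return dim + total if dim < 0 else dim

assert resolve_dim(3, -1) == 2               # last axis of a 3-d tensor
assert resolve_dim(3, 3, extra=True) == 3    # one past the end is legal with extra
assert resolve_dim(3, -1, extra=True) == 3   # -1 resolves to the same slot
```
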
@@ -204,3 +232,19 @@ class MovementMixin:
     if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
     if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})")
     return self._mop(Ops.RESHAPE, arg=new_shape) if new_shape != self.shape else self
+
+  def flatten(self, start_dim=0, end_dim=-1) -> Self:
+    """
+    Flattens the tensor by reshaping it into a one-dimensional tensor.
+    If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(8).reshape(2, 2, 2)
+    print(t.flatten().numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.flatten(start_dim=1).numpy())
+    ```
+    """
+    start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
+    return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])

@@ -187,10 +187,18 @@ class UOp(MathMixin, MovementMixin, metaclass=UOpMetaClass):
   def _shape(self) -> tuple[sint, ...]|None:
     match self.op:
       # late ops don't have shape
-      case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.INDEX | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
+      case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
            Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST | Ops.CONTRACT:
         return None
 
+      case Ops.INDEX:
+        # non pointer index doesn't have a shape
+        if not isinstance(self.dtype, PtrDType): return None
+        # fully indexed doesn't have a shape. TODO: remove this
+        if len(self.src[1:]) == len(self.src[0].shape): return None
+        # pointer index
+        return self.src[0].shape[len(self.src[1:]):]
+
       # some ops init the shape
       case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND: return () if self._device is not None else None
       case Ops.BUFFER: return (self.arg,)

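The new `Ops.INDEX` case is what gives a partially-indexed pointer a shape: the not-yet-indexed tail of the buffer's shape. A toy model of the rule (hypothetical helper, mirroring the three branches above):

```python
def index_shape(buf_shape, n_indices, is_ptr):
  if not is_ptr: return None                    # non pointer index has no shape
  if n_indices == len(buf_shape): return None   # fully indexed has no shape (for now)
  return buf_shape[n_indices:]                  # pointer index keeps the remaining dims

assert index_shape((16, 16), 1, True) == (16,)  # src[G, :] leaves one dim to index
assert index_shape((16, 16), 2, True) is None   # src[G, R] is fully indexed
```
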
@@ -344,7 +352,13 @@ class UOp(MathMixin, MovementMixin, metaclass=UOpMetaClass):
   def index(self, *srcs:UOp|None, ptr=False, **kwargs):
     return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
   def __getitem__(self, idx):
-    return self.index(*[UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in argfix(idx)])
+    idx = argfix(idx)
+    assert len(idx) == len(self.shape), f"__getitem__ shape mismatch, indexing {self.shape} with {len(idx)} args"
+    if len(slice_idx:=[i for i,x in enumerate(idx) if isinstance(x, slice)]):
+      perm = self.permute(tuple([i for i in range(self.ndim) if i not in slice_idx] + slice_idx))
+      return perm.index(*[UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in idx if not isinstance(x, slice)], ptr=True)
+    else:
+      return self.index(*[UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in idx])
   def const_like(self, b:ConstLike):
     # constants can optionally have a DEVICE source
     return UOp.const(self.dtype, b, device=self._device, shape=self._shape)

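`__getitem__` now accepts partial indexing with `:` slices: sliced dims are permuted to the end, the remaining dims are indexed with `ptr=True`, and the sliced dims survive as the pointer's shape (which the new `_shape` case above reads back). A toy sketch of just the dim bookkeeping (hypothetical helper):

```python
def split_index(idx):
  slice_idx = [i for i, x in enumerate(idx) if isinstance(x, slice)]
  perm = [i for i in range(len(idx)) if i not in slice_idx] + slice_idx  # sliced dims last
  point = [x for x in idx if not isinstance(x, slice)]                   # dims actually indexed
  return perm, point

assert split_index(("G", slice(None))) == ([0, 1], ["G"])  # src[G, :]: dim 1 survives
assert split_index((slice(None), "J")) == ([1, 0], ["J"])  # src[:, J]: dim 0 moved last, survives
```
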