slicing + allclose

George Hotz
2025-11-03 12:00:45 +08:00
parent a317d6e625
commit c9a1e35b1e
5 changed files with 98 additions and 53 deletions

View File

@@ -1,5 +1,6 @@
 import unittest
 from tinygrad import Tensor, UOp, Context
+from tinygrad.dtype import AddrSpace
 from tinygrad.uop.ops import KernelInfo, AxisType

 # **** kernels ****
@@ -9,15 +10,18 @@ def custom_arange_kernel(C:UOp) -> UOp:
   return C[i].store(i.cast(C.dtype.base)).end(i).sink(arg=KernelInfo(name=f"custom_arange_{C.size}"))

 def custom_add_one_kernel(B:UOp, A:UOp) -> UOp:
+  A,B = A.flatten(), B.flatten()
   assert B.size == A.size
   i = UOp.range(A.size, 0)
   return B[i].store(A[i] + 1).end(i).sink(arg=KernelInfo(name=f"add_one_{A.size}"))

 def custom_elementwise_add_kernel(C:UOp, A:UOp, B:UOp) -> UOp:
+  C,A,B = C.flatten(), A.flatten(), B.flatten()
   i = UOp.range(C.size, 0)
   return C[i].store(A[i]+B[i]).end(i).sink(arg=KernelInfo(name=f"custom_add_kernel_{C.size}")).simplify()

 def custom_elementwise_addmul_kernel(C:UOp, D:UOp, A:UOp, B:UOp) -> UOp:
+  C,D,A,B = C.flatten(), D.flatten(), A.flatten(), B.flatten()
   assert C.size == D.size
   i = UOp.range(C.size, 0)
   store_c = C[i].store(A[i]+B[i])
@@ -39,13 +43,22 @@ def custom_sum(B:UOp, A:UOp) -> UOp:
   return B.sink(arg=KernelInfo(name=f"custom_sum_{A.shape[0]}", opts_to_apply=()))

 def flip_contract_kernel(dest:UOp, src:UOp):
   assert dest.size%4 == 0
-  i = UOp.range(dest.size//4, 0)
-  j = UOp.range(4, 1, AxisType.UPCAST)
-  vec = src[i*4+j].contract(j)
-  store = UOp.group(*[dest[i*4+k].store(vec.gep(3-k)) for k in range(4)])
+  i = UOp.range(dest.shape[0], 0)
+  j = UOp.range(dest.shape[1], 1, AxisType.UPCAST)
+  vec = src[i, j].contract(j)
+  store = UOp.group(*[dest[i, k].store(vec.gep(3-k)) for k in range(4)])
   return store.end(i).sink(arg=KernelInfo(name=f"flip_contract_{dest.size}", opts_to_apply=()))

+def slice_sum_kernel(dest:UOp, src:UOp):
+  G = UOp.range(src.shape[0], 0, AxisType.GLOBAL)
+  slice_src = src[G, :]
+  reg = UOp.placeholder((1,), dest.dtype.base, 0, addrspace=AddrSpace.REG)
+  reg = reg.after(G)[0].set(0)
+  R = UOp.range(src.shape[1], 1, AxisType.REDUCE)
+  reg = reg[0].set(reg[0] + slice_src[R], end=R)
+  ast = dest[G].set(reg[0], end=G)
+  return ast.sink(arg=KernelInfo(name=f"slice_sum_{src.shape[0]}_{src.shape[1]}", opts_to_apply=()))
+
 # **** backward callbacks ****
 def backward_gemm(gradient:UOp, kernel:UOp) -> tuple[UOp, UOp]:
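In plain Python, the new slice_sum_kernel is just a row-wise sum: one GLOBAL loop over rows, a register accumulator zeroed per row, and a REDUCE loop over columns. A rough sketch of the loop nest it describes (the reference function and its names are illustrative, not part of the commit):

# what slice_sum_kernel computes, written as plain loops
def slice_sum_reference(dest, src):
  for G in range(len(src)):        # UOp.range(src.shape[0], 0, AxisType.GLOBAL)
    reg = 0                        # REG placeholder, .set(0) after G opens
    for R in range(len(src[G])):   # UOp.range(src.shape[1], 1, AxisType.REDUCE)
      reg += src[G][R]             # reg[0].set(reg[0] + slice_src[R], end=R)
    dest[G] = reg                  # dest[G].set(reg[0], end=G)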
@@ -111,6 +124,12 @@ class TestCustomKernel(unittest.TestCase):
     b = Tensor.custom_kernel(tst, a, fxn=custom_sum)[0]
     self.assertEqual(b.item(), 15)

+  def test_slice_sum(self):
+    A = Tensor.randn(16, 16)
+    B = Tensor.empty(16)
+    B = Tensor.custom_kernel(B, A, fxn=slice_sum_kernel)[0]
+    self.assertTrue(B.allclose(A.sum(1)))
+
   def test_gemm(self):
     N = 16
     a = Tensor.randn(N, N)

View File

@@ -4,7 +4,7 @@ from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC
 from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat
 from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
 from tinygrad.renderer import Renderer
-from tinygrad.dtype import dtypes
+from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.helpers import panic
# import all pattern matchers here
@@ -19,13 +19,19 @@ from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_s
 from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops
 from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize

+pm_syntactic_sugar = PatternMatcher([
+  # INDEX on ptr INDEX concats them
+  (UPat(Ops.INDEX, name="i1").f(Ops.INDEX, name="i2", allow_any_len=True),
+   lambda i1,i2: i2.replace(src=i1.src+i2.src[1:]) if isinstance(i1.dtype, PtrDType) and not isinstance(i2.dtype, PtrDType) else None),
+])
+
 def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
   if ren is None: ren = Renderer()
   if SPEC: type_verify(sink, kernel_spec)

   # preprocess
-  sink = graph_rewrite(sink, pm_mops, name="early movement ops")
+  sink = graph_rewrite(sink, pm_mops+pm_syntactic_sugar, name="early movement ops", bottom_up=True)

   # first we optimize
   if optimize:
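The new pm_syntactic_sugar rule folds an INDEX applied on top of a pointer INDEX into a single INDEX whose index list is the concatenation of the two (i1.src + i2.src[1:]). Roughly, on a hypothetical (16, 16) buffer UOp named buf:

# before the rewrite: two stacked INDEX ops
#   row = buf[G, :]    ->  INDEX(buf, G)            ptr dtype, shape (16,)
#   val = row[R]       ->  INDEX(INDEX(buf, G), R)  value dtype
# after the rewrite, val is the single op the kernel would get from buf[G, R]:
#   val  ==  INDEX(buf, G, R)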

View File

@@ -1435,11 +1435,6 @@ class Tensor(MathMixin, MovementMixin):
     final_shape = [r*s for r,s in zip(repeats, base_shape)]
     return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape)

-  def _resolve_dim(self, dim:int, *, extra:bool=False) -> int:
-    total = self.ndim + int(extra)
-    if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
-    return dim + total if dim < 0 else dim
-
   def split(self, sizes:int|Sequence[int], dim:int=0) -> tuple[Tensor, ...]:
     """
     Splits the tensor into chunks along the dimension specified by `dim`.
@@ -1597,22 +1592,6 @@ class Tensor(MathMixin, MovementMixin):
     order[dim0], order[dim1] = order[dim1], order[dim0]
     return self.permute(order)

-  def flatten(self, start_dim=0, end_dim=-1) -> Tensor:
-    """
-    Flattens the tensor by reshaping it into a one-dimensional tensor.
-
-    If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.arange(8).reshape(2, 2, 2)
-    print(t.flatten().numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.flatten(start_dim=1).numpy())
-    ```
-    """
-    start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
-    return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
-
   def unflatten(self, dim:int, sizes:tuple[int,...]) -> Tensor:
     """
     Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
@@ -1927,6 +1906,12 @@ class Tensor(MathMixin, MovementMixin):
     is_nan_close = (self.isnan() & other.isnan()) & equal_nan
     return is_finite_close | is_infinite_close | is_nan_close

+  def allclose(self, other:Tensor, rtol:float=1e-05, atol:float=1e-08, equal_nan=False) -> bool:
+    """
+    Check if all elements of `self` and `other` are close. Returns True or False.
+    """
+    return bool(self.isclose(other, rtol=rtol, atol=atol, equal_nan=equal_nan).all().item())
+
   def mean(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
     """
     Returns the mean value of the tensor along the specified axis or axes.
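The new allclose simply reduces the existing isclose, whose finite-element rule (the NumPy convention, with the inf/NaN cases handled in the lines above) is |self - other| <= atol + rtol * |other|. A quick usage sketch with the default tolerances:

a = Tensor([1.0, 2.0, 3.0])
b = Tensor([1.0, 2.0, 3.0 + 1e-7])
print(a.allclose(b))                      # True: 1e-7 is within atol + rtol*|b|
print(a.allclose(b, rtol=0, atol=1e-9))   # False: tolerances too tight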
@@ -4194,29 +4179,6 @@ class Tensor(MathMixin, MovementMixin):
   # ***** Tensor Properties *****

-  @property
-  def ndim(self) -> int:
-    """
-    Returns the number of dimensions in the tensor.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor([[1, 2], [3, 4]])
-    print(t.ndim)
-    ```
-    """
-    return len(self.shape)
-
-  def numel(self) -> sint:
-    """
-    Returns the total number of elements in the tensor.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
-    print(t.numel())
-    ```
-    """
-    return prod(self.shape)
-
   def element_size(self) -> int:
     """
     Returns the size in bytes of an individual element in the tensor.

View File

@@ -183,6 +183,34 @@ class MovementMixin:
   def shape(self) -> tuple["sint", ...]: raise NotImplementedError

   # great functions you get!
+  @property
+  def ndim(self) -> int:
+    """
+    Returns the number of dimensions in the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[1, 2], [3, 4]])
+    print(t.ndim)
+    ```
+    """
+    return len(self.shape)
+
+  def numel(self) -> "sint":
+    """
+    Returns the total number of elements in the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+    print(t.numel())
+    ```
+    """
+    return prod(self.shape)
+
+  def _resolve_dim(self, dim:int, *, extra:bool=False) -> int:
+    total = self.ndim + int(extra)
+    if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
+    return dim + total if dim < 0 else dim
+
   def view(self, shape, *args) -> Self:
     """`.view` is an alias for `.reshape`."""
     return self.reshape(shape, *args)
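_resolve_dim normalizes negative axis arguments against ndim, and extra=True widens the valid range by one for call sites that insert a new dim. A quick sketch, assuming a 3-dimensional tensor:

t = Tensor.empty(2, 3, 4)
print(t._resolve_dim(-1))             # 2
print(t._resolve_dim(1))              # 1
print(t._resolve_dim(3, extra=True))  # 3: one past the end, valid as an insertion point
# t._resolve_dim(3) raises IndexError: dim=3 out of range [-3, 2]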
@@ -204,3 +232,19 @@ class MovementMixin:
     if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
     if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})")
     return self._mop(Ops.RESHAPE, arg=new_shape) if new_shape != self.shape else self
+
+  def flatten(self, start_dim=0, end_dim=-1) -> Self:
+    """
+    Flattens the tensor by reshaping it into a one-dimensional tensor.
+
+    If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(8).reshape(2, 2, 2)
+    print(t.flatten().numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.flatten(start_dim=1).numpy())
+    ```
+    """
+    start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
+    return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
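Since UOp also inherits MovementMixin (see the class line in the last file), buffer UOps inside custom kernels get flatten/ndim/numel for free; this is what the A.flatten()/B.flatten() calls in the first file rely on. A minimal sketch modeled on custom_add_one_kernel (my_copy_kernel is illustrative, not part of the commit):

def my_copy_kernel(B:UOp, A:UOp) -> UOp:
  A, B = A.flatten(), B.flatten()   # MovementMixin.flatten, same code path as Tensor.flatten
  i = UOp.range(A.size, 0)
  return B[i].store(A[i]).end(i).sink(arg=KernelInfo(name=f"copy_{A.size}"))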

View File

@@ -187,10 +187,18 @@ class UOp(MathMixin, MovementMixin, metaclass=UOpMetaClass):
   def _shape(self) -> tuple[sint, ...]|None:
     match self.op:
       # late ops don't have shape
-      case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.INDEX | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
+      case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
           Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST | Ops.CONTRACT:
         return None
+      case Ops.INDEX:
+        # non pointer index doesn't have a shape
+        if not isinstance(self.dtype, PtrDType): return None
+        # fully indexed doesn't have a shape. TODO: remove this
+        if len(self.src[1:]) == len(self.src[0].shape): return None
+        # pointer index: keep the remaining (unindexed) dims of the source
+        return self.src[0].shape[len(self.src[1:]):]
       # some ops init the shape
       case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND: return () if self._device is not None else None
       case Ops.BUFFER: return (self.arg,)
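Under the new Ops.INDEX case, a partially-applied pointer INDEX keeps the trailing, still-unindexed dims of its source. For a hypothetical (16, 16) buffer (sketch):

#   src                    shape (16, 16)  buffer
#   src[G, :]   (ptr)      shape (16,)     one of two dims indexed
#   src[G, :][R] (value)   shape None      non-pointer dtype, no shape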
@@ -344,7 +352,13 @@ class UOp(MathMixin, MovementMixin, metaclass=UOpMetaClass):
   def index(self, *srcs:UOp|None, ptr=False, **kwargs):
     return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)

   def __getitem__(self, idx):
-    return self.index(*[UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in argfix(idx)])
+    idx = argfix(idx)
+    assert len(idx) == len(self.shape), f"__getitem__ shape mismatch, indexing {self.shape} with {len(idx)} args"
+    if len(slice_idx:=[i for i,x in enumerate(idx) if isinstance(x, slice)]):
+      # move sliced dims to the end, index the others, and return a pointer to what's left
+      perm = self.permute(tuple([i for i in range(self.ndim) if i not in slice_idx] + slice_idx))
+      return perm.index(*[UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in idx if not isinstance(x, slice)], ptr=True)
+    else:
+      return self.index(*[UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in idx])

   def const_like(self, b:ConstLike):
     # constants can optionally have a DEVICE source
     return UOp.const(self.dtype, b, device=self._device, shape=self._shape)
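Reading the new __getitem__ end to end for slice_sum_kernel's src[G, :] (src shape (16, 16)): slice positions are collected, non-sliced dims are permuted to the front, and the remaining integer/range indices produce a pointer INDEX over the dims left behind. A trace sketch:

# idx = (G, slice(None, None, None))
#   slice_idx = [1]                    dim 1 is sliced
#   perm = src.permute((0, 1))         identity here; the sliced dim is already last
#   perm.index(G, ptr=True)            ptr INDEX, _shape (16,) per the case above
# slice_src[R] then indexes that pointer; pm_syntactic_sugar (second file)
# folds the chain into INDEX(src, G, R), as if the kernel had written src[G, R]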